# Probe Concepts
## Setup
[ ]:
import importlib.util


def module_available(name: str) -> bool:
    """Return True when module ``name`` can be located, without raising.

    ``importlib.util.find_spec`` imports the parent package of a dotted name,
    so looking up ``"google.colab"`` raises ``ModuleNotFoundError`` when no
    ``google`` package is installed at all. Any such failure simply means the
    module is not available.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # ImportError: a parent package is missing or failed to import.
        # ValueError: the module is already loaded but has no usable spec.
        return False


# On Colab, DEV selects "colab-dev" (install from a clone of `main`) over
# "colab" (install the published package).
DEV = True

if module_available("google.colab"):
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
%pip install -q lczerolens[hf]
elif MODE == "colab-dev":
!rm -r lczerolens
!git clone https://github.com/Xmaster6y/lczerolens -b main
%pip install -q ./lczerolens --extra hf
## Load a Model
Load two Leela networks, Maia 1100 and Maia 1900, from the Hugging Face Hub:
[3]:
from lczerolens import LczeroModel

# Load the two networks that are probed and compared later in the notebook.
maia_1100, maia_1900 = (
    LczeroModel.from_hf(repo_id)
    for repo_id in ("lczerolens/maia-1100", "lczerolens/maia-1900")
)
## Import a Game Dataset
[4]:
from datasets import load_dataset

# Open the train split in streaming mode; only a small sample is used below.
stream_options = {"split": "train", "streaming": True}
dataset = load_dataset("lczerolens/tcec-boards", **stream_options)
dataset
[4]:
IterableDataset({
features: ['gameid', 'moves', 'fen', 'label'],
num_shards: 2
})
[5]:
from datasets import Dataset

# Take a fixed-seed sample of 1,000 boards from the shuffled stream.
sampled_stream = dataset.shuffle(seed=42).take(1_000)


def iter_sampled_boards():
    """Yield the sampled boards one at a time for ``Dataset.from_generator``."""
    yield from sampled_stream


# Materialise the streamed sample into a regular `Dataset` so it can be split.
sampled_ds = Dataset.from_generator(iter_sampled_boards, features=sampled_stream.features)

# Seed the split too: without it the train/test membership (and every number
# reported further down) changes on each run.
work_ds = sampled_ds.train_test_split(test_size=0.2, seed=42)
train_ds = work_ds["train"]
test_ds = work_ds["test"]
work_ds
[5]:
DatasetDict({
train: Dataset({
features: ['gameid', 'moves', 'fen', 'label'],
num_rows: 800
})
test: Dataset({
features: ['gameid', 'moves', 'fen', 'label'],
num_rows: 200
})
})
## Create a Concept
[6]:
from lczerolens.concepts import HasThreat
from lczerolens.data import BoardData

# NOTE(review): presumably labels positions where a knight ("N") is threatened,
# relative to the side to move -- confirm against the lczerolens documentation.
concept = HasThreat(piece="N", relative=True)

# Turn each split into parallel (boards, labels) sequences for the concept.
train_boards, train_labels = BoardData.concept_collate_fn([*train_ds], concept)
test_boards, test_labels = BoardData.concept_collate_fn([*test_ds], concept)

# Report the number of positive labels per split to expose class imbalance.
train_positives = sum(train_labels)
test_positives = sum(test_labels)
print(f"Positive examples (train): {train_positives}")
print(f"Positive examples (test): {test_positives}")
Positive examples (train): 97
Positive examples (test): 27
## Train Linear Probes
[7]:
import torch
from tensordict import TensorDict
from tdhook.latent.probing import Probing, SklearnProbeManager
from sklearn.linear_model import LogisticRegression

# Raw string: "\d" inside a normal string literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+). The pattern itself is unchanged.
PROBED_MODULES = r"block\d/conv2/relu"


def run_probes(model, probe_manager, train_set, test_set, pattern=PROBED_MODULES):
    """Fit probes on ``train_set`` and evaluate them on ``test_set``.

    Args:
        model: network exposing ``prepare_boards`` (``maia_1100`` / ``maia_1900``).
        probe_manager: supplies ``probe_factory`` to ``Probing`` and exposes the
            resulting ``predict_metrics``.
        train_set: ``(boards, labels)`` pair sent with ``step_type="fit"``.
        test_set: ``(boards, labels)`` pair sent with ``step_type="predict"``.
        pattern: module-name pattern handed to ``Probing``.

    Returns:
        A shallow copy of ``probe_manager.predict_metrics`` taken after both passes.
    """
    probing = Probing(pattern, probe_manager.probe_factory, additional_keys=["labels", "step_type"])
    with probing.prepare(model) as hooked_module, torch.no_grad():
        # Order matters: the "fit" pass has to run before the "predict" pass.
        for (boards, labels), step_type in ((train_set, "fit"), (test_set, "predict")):
            inputs = TensorDict(
                {
                    "board": model.prepare_boards(*boards),
                    "labels": torch.tensor(labels),
                    "step_type": step_type,
                },
                batch_size=len(boards),
            )
            hooked_module(inputs)
    # Snapshot the mapping so a later `reset_metrics()` cannot empty the result.
    # NOTE(review): only needed if the manager clears its dict in place -- confirm.
    return dict(probe_manager.predict_metrics)


probe_manager = SklearnProbeManager(LogisticRegression, {"max_iter": 1000}, lambda x, y: concept.compute_metrics(x, y))
train_set = (train_boards, train_labels)
test_set = (test_boards, test_labels)

maia_1100_metrics = run_probes(maia_1100, probe_manager, train_set, test_set)

# Start the second model from a clean manager so nothing carries over.
probe_manager.reset_probes()
probe_manager.reset_metrics()

maia_1900_metrics = run_probes(maia_1900, probe_manager, train_set, test_set)
## Render the Results
[8]:
import re
import matplotlib.pyplot as plt
# --------------------------------------------------
def get_layer_number(key: str) -> int:
    """Extract the block index from a probe key such as ``block3/conv2/relu``.

    The pattern is searched anywhere inside ``key``, so extra text before or
    after it is tolerated.

    Raises:
        ValueError: if ``key`` contains no ``block<N>/conv2/relu`` segment.
    """
    match = re.search(r"block(\d+)/conv2/relu", key)
    if match is None:
        # Fail with the offending key instead of an AttributeError on None.
        raise ValueError(f"Cannot extract a layer number from key {key!r}")
    return int(match.group(1))
def collect(metrics_dict):
    """Regroup per-layer metrics by metric name.

    ``metrics_dict`` maps a layer key to a ``{metric: value}`` dict. The result
    maps each metric name to ``{"layers": (...), "values": (...)}``, where both
    entries are tuples ordered by ascending layer number.
    """
    # First pass: gather (layer, value) pairs under each metric name.
    pairs_by_metric = {}
    for layer_key, layer_metrics in metrics_dict.items():
        layer_idx = get_layer_number(layer_key)
        for metric_name, metric_value in layer_metrics.items():
            pairs_by_metric.setdefault(metric_name, []).append((layer_idx, metric_value))

    # Second pass: sort each metric's pairs and split them back into two tuples.
    grouped = {}
    for metric_name, pairs in pairs_by_metric.items():
        ordered_layers, ordered_values = zip(*sorted(pairs))
        grouped[metric_name] = {"layers": ordered_layers, "values": ordered_values}
    return grouped
metrics_1100 = collect(maia_1100_metrics)
metrics_1900 = collect(maia_1900_metrics)

# Size the grid from the number of metrics so that none is silently dropped
# (a fixed 2x2 grid would truncate anything beyond four).
metric_names = list(metrics_1100)
n_cols = 2
n_rows = max(1, -(-len(metric_names) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), sharex=True, squeeze=False)
axes = axes.ravel()

# One panel per metric, with both models overlaid for direct comparison.
for ax, metric in zip(axes, metric_names):
    ax.plot(metrics_1100[metric]["layers"], metrics_1100[metric]["values"], marker="o", label="Maia 1100")
    ax.plot(metrics_1900[metric]["layers"], metrics_1900[metric]["values"], marker="s", label="Maia 1900")
    ax.set_title(metric.replace("_", " ").title())
    ax.set_xlabel("Layer")
    ax.set_ylabel(metric)
    ax.grid(True)
    ax.legend()

# Hide any panel left over when the metric count does not fill the grid.
for ax in axes[len(metric_names):]:
    ax.axis("off")

# The curves compare the two models, so the title names them.
fig.suptitle("Probe Metrics per Layer (Maia 1100 vs Maia 1900)", fontsize=14)
plt.tight_layout()
plt.show()
[ ]: