Evaluate Models on Puzzles#

Open In Colab

Setup#

[1]:
MODE = "local"  # "colab" | "colab-dev" | "local"
[2]:
if MODE == "colab":
    !pip install -q lczerolens
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    !pip install -q ./lczerolens
[3]:
!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx
!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx
Downloading...
From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X
To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/features/lc0-19-1876.onnx
100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 48.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd
To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/features/lc0-19-4508.onnx
100%|██████████████████████████████████████| 97.1M/97.1M [00:05<00:00, 16.8MB/s]

Load a Model#

Load a leela network from file (already converted to onnx):

[1]:
from lczerolens import LczeroModel

strong_model = LczeroModel.from_path("lc0-19-4508.onnx")
weak_model = LczeroModel.from_path("lc0-19-1876.onnx")
/Users/xmaster/Work/lczerolens/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Import a Game Dataset#

[2]:
from datasets import load_dataset

dataset = load_dataset("lczerolens/tcec-games", split="train")
dataset
[2]:
Dataset({
    features: ['gameid', 'moves'],
    num_rows: 23297
})
[3]:
from lczerolens.play import Game


def boards_from_dict(batch):
    new_batch = []
    for game_tuple in zip(*batch.values()):
        game = Game.from_dict(dict(zip(batch.keys(), game_tuple)))
        new_batch.extend(game.to_boards(skip_book_exit=True, output_dict=True))
    return {k: [d[k] for d in new_batch] for k in new_batch[0].keys()}


board_dataset = dataset.select(range(1000)).map(boards_from_dict, batched=True, batch_size=100)
board_datasetdict = board_dataset.train_test_split(test_size=0.1, seed=42)
board_datasetdict
[3]:
DatasetDict({
    train: Dataset({
        features: ['gameid', 'moves', 'fen'],
        num_rows: 115561
    })
    test: Dataset({
        features: ['gameid', 'moves', 'fen'],
        num_rows: 12841
    })
})
[24]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    board_datasetdict["train"], batch_size=100, shuffle=True, collate_fn=Game.board_collate_fn
)
first_batch = next(iter(train_dataloader))
print(first_batch)
([LczeroBoard('1R6/q3r1k1/4p2p/3p1pp1/1Q1P1P2/4PK1P/5P2/8 b - - 5 104'), LczeroBoard('8/8/7R/p1k4P/rp6/6K1/6P1/8 w - - 0 51'), LczeroBoard('2r5/8/p2pQppk/1P2b2p/P2NP1qP/4PRP1/5RK1/2r5 w - - 2 39'), LczeroBoard('8/4k3/6P1/p2K1B2/P7/8/8/6b1 w - - 65 120'), LczeroBoard('8/8/4pp2/6k1/pp2P2p/3q4/P3R1PP/6RK w - - 22 49'), LczeroBoard('6Q1/1K6/1P6/5k2/1P6/3N4/7q/6b1 b - - 0 67'), LczeroBoard('rb2r1k1/1p3ppn/2p4p/2Qp4/p2P2PP/4PP2/PP3BBq/1R2RK2 w - - 8 28'), LczeroBoard('3n4/p2k1p2/P1p2R2/1p2P1Q1/1P2q3/2pr1NP1/5PK1/8 w - - 6 52'), LczeroBoard('3r2r1/pq2bk2/2p2p2/Pp1n1p1b/3PN2P/1B3P2/1P1B1QR1/3R3K w - - 0 29'), LczeroBoard('r1b1kbr1/pp1p3p/n2Pp1p1/1N3pn1/q1P4P/PpB1P3/3N1PP1/R2QKB1R b KQq - 0 15'), LczeroBoard('8/5p2/2bN1pk1/7p/3K3P/6P1/8/8 b - - 66 183'), LczeroBoard('8/4K3/3P2n1/8/3k4/8/8/7B w - - 1 224'), LczeroBoard('2r3k1/1p2ppbp/pn1p2p1/3P4/1q1NP3/1P2BP2/P1Q3PP/3R2K1 w - - 1 21'), LczeroBoard('3r1rk1/pb2q1pp/1pnpp3/1N6/2P1BP2/2P1Q1P1/P6P/R2R2K1 b - - 6 25'), LczeroBoard('1K6/P3kp1p/2R5/6p1/8/r5P1/7P/8 w - - 5 59'), LczeroBoard('8/8/2rBp3/p1P1Pk2/P4b2/1P5p/5R2/7K w - - 2 82'), LczeroBoard('8/8/5pk1/8/8/2R5/7r/4K3 b - - 3 127'), LczeroBoard('8/8/5k2/2pp1P1p/5K1P/8/3Nb3/8 b - - 5 62'), LczeroBoard('3Rk2r/p1p1b1pp/1p2p3/2p1Pn2/4N3/2P3P1/PP3P1P/R1B3K1 b - - 0 16'), LczeroBoard('5R2/8/4k3/p2p4/P7/3Kp3/3b4/8 b - - 5 181'), LczeroBoard('2r1r1k1/2nnq2p/pp1p2p1/2bP1p2/P3PP2/2BQ2PP/4N1BK/1R3R2 w - - 2 26'), LczeroBoard('2rnr1k1/ppqbpp1p/5bp1/6N1/3PB2P/1QP5/P2B1PP1/1R2R1K1 b - - 0 18'), LczeroBoard('8/8/5k1p/2N1p1p1/5p2/5PP1/1r2BK1P/8 w - - 4 44'), LczeroBoard('3r1k2/8/8/5Np1/1pB4p/7P/4K1P1/8 w - - 6 53'), LczeroBoard('5k2/N4p2/3bp2p/1B1p2p1/2q3P1/R4P2/5PK1/8 b - - 10 50'), LczeroBoard('rn1q1rk1/pb2bppp/5n2/2pp4/8/1PB2NP1/P1QNPPBP/R4RK1 w - - 0 14'), LczeroBoard('1R6/5p2/4p1k1/8/3Pp1Pb/4P3/r2BK3/8 w - - 63 106'), LczeroBoard('6k1/3R1p2/5p1p/8/2KN3P/r4bP1/8/8 w - - 0 51'), LczeroBoard('r2r4/p3kp2/5p1p/3Rp3/Ppp1P1P1/8/1PP2PP1/3R2K1 w - - 0 26'), LczeroBoard('2r3k1/1p3ppp/1B2p1q1/4P3/1bP2Q2/1r6/5RPP/3R2K1 w - - 2 30'), LczeroBoard('b1r3k1/1q2npp1/p6p/1pP5/4pPB1/P2rB1PP/1Q6/2R1R1K1 w - - 3 35'), LczeroBoard('2r5/4p1kp/4Pp2/8/2r1Q1PP/8/4PK2/8 w - - 3 43'), LczeroBoard('5rk1/R7/1p5p/2n2B2/2b2r2/P1R3N1/6P1/6K1 w - - 0 32'), LczeroBoard('1rbq1rk1/p5pp/2p1n3/1pPp1p2/3Pn3/P1NB1N2/1P3PPP/R2Q1RK1 b - - 1 16'), LczeroBoard('3q4/2p2b1k/1bp2ppp/2n5/2N1Pp2/2P4P/P1QN2PB/5K2 w - - 2 32'), LczeroBoard('6k1/2n2pp1/1pN4p/8/8/2P3P1/r4PP1/5RK1 w - - 0 29'), LczeroBoard('r3r1k1/1pq1ppbp/2p2np1/p7/2PP4/1P3NP1/PB2QP1P/R3RK2 b - - 1 19'), LczeroBoard('1r1r4/3nkn2/ppbppp2/6p1/N1P1P2p/1P3P1P/P1NRBKP1/3R4 w - - 3 56'), LczeroBoard('2B5/8/8/4k1p1/PN2n3/3p2PP/4b1K1/8 w - - 0 62'), LczeroBoard('6R1/3b1k2/8/1r6/3BK3/4N3/p7/8 w - - 31 70'), LczeroBoard('1r2k2r/5qp1/p2bpn2/6B1/NpP1p3/1P6/P5QP/4RRK1 w k - 2 27'), LczeroBoard('2k5/6R1/3R4/8/2Kp1rP1/5P2/8/1r6 w - - 4 118'), LczeroBoard('8/8/4k3/2R3p1/5b2/5P2/6Kp/8 b - - 91 109'), LczeroBoard('r1br2k1/p1q2pp1/2n1p2p/2N1b3/Pp6/1B1P1Q2/1PP2PPP/R1B1R2K b - - 2 18'), LczeroBoard('8/r4pp1/3p4/2pk2Pp/p6P/R7/4K3/8 w - - 0 66'), LczeroBoard('r3k2r/3n1pp1/1p2p1p1/2bpP3/3N1P2/P2B4/1PRB2PP/5R1K b kq - 0 25'), LczeroBoard('5r1k/p4qp1/1r5p/1p1PR3/8/3Q2P1/3R1P1P/6K1 b - - 8 33'), LczeroBoard('7k/7p/8/r5P1/4B1K1/8/8/8 w - - 57 75'), LczeroBoard('3rr1k1/1p3q1p/p1p3p1/N7/1P6/P6P/1Q2nPB1/B1R4K b - - 4 31'), LczeroBoard('8/p1rq2bk/Rpn3n1/2r2p1p/3p2pP/1Q1P2P1/3BPPB1/R4NK1 b - - 13 36'), LczeroBoard('r2q3r/1p1b1kpp/1b2p3/p2pPp2/P2P1P2/4B3/1P1QB1PP/R4RK1 b - - 1 20'), LczeroBoard('6r1/ppk5/4p3/3pPp2/n1pP1P1p/P1P4B/2PK1R1P/8 b - - 1 32'), LczeroBoard('4K3/8/5k2/5n2/8/5R2/8/3r4 w - - 22 93'), LczeroBoard('8/3P4/1b3k2/7p/1P3P2/3NR3/6K1/3r4 b - - 2 74'), LczeroBoard('R7/2r5/3p1b2/3k1p1p/3p1P1P/3P2P1/8/3KB3 b - - 36 83'), LczeroBoard('r3kb1r/pb1n1ppp/1q2p3/1ppP4/8/2N3P1/PPQ1PPBP/R1B1K2R b KQkq - 0 12'), LczeroBoard('8/1p4k1/2p3p1/4Pp1p/PbBp1P2/5KPP/8/8 w - - 2 42'), LczeroBoard('8/4B3/8/4nk1p/2p2p1P/2P2P2/4K3/8 b - - 75 100'), LczeroBoard('4r1k1/5n2/8/1Q4p1/4q3/1PB5/1KP1p3/7R b - - 1 54'), LczeroBoard('1nr5/2r2pk1/4p1pp/p3P3/Pp1R1P1P/6P1/1P1R2K1/3B4 b - - 0 40'), LczeroBoard('8/p4pk1/1p4p1/4K3/1R1P4/4r3/P7/8 w - - 1 50'), LczeroBoard('4r1k1/3bP3/6p1/3r4/1P2R1PK/P4R2/8/8 w - - 35 69'), LczeroBoard('3r4/1b1pnr1k/1qn1p1pp/1p3p2/1P3N2/2NPPPP1/1Q4BP/2R1R1K1 b - - 0 25'), LczeroBoard('5b2/p6k/Pp6/1P3N1R/5KP1/3P4/4rP2/8 b - - 0 63'), LczeroBoard('7r/4k2r/2bp1p1p/p1p1pPp1/P1P1P1P1/1P4P1/1KB4R/7R b - - 16 111'), LczeroBoard('rq3rk1/1p5p/p1p1pbp1/B2p1pnP/2PP4/1P2P3/P1Q2PP1/2RB1RK1 b - - 2 20'), LczeroBoard('2r1rqk1/1b5p/p2pPnp1/1NpP2B1/1P3Pn1/P4B1R/3Q4/R3K3 w Q - 0 27'), LczeroBoard('5k2/r4pp1/7p/3R4/4B1PP/1n6/5PK1/8 b - - 1 36'), LczeroBoard('8/5pkp/3p1pp1/3P4/4P1P1/p2P3P/2Q5/q2K4 w - - 6 41'), LczeroBoard('1k6/5r2/2p1pb2/1pPr1p1p/pP1PpPpP/PqB1P1P1/2R2Q2/R5K1 w - - 98 101'), LczeroBoard('r3k2r/1pq1bpp1/p1bp1n1p/2n1p3/P1B1P3/2N2P2/1PPNQ1PP/R1BR2K1 b kq - 2 14'), LczeroBoard('1q5k/5rpp/ppb2n2/2pNpP2/P1P1P2Q/1P3P2/4B3/3RK3 b - - 10 50'), LczeroBoard('r1nq1rk1/3b2bp/p4np1/2pPNpB1/1pP2p1Q/8/PP2B1PP/R2N1RK1 w - - 2 20'), LczeroBoard('8/4R3/7r/1P3k2/2P1N2p/7b/1K6/8 w - - 1 60'), LczeroBoard('5r2/4qppk/1b6/1P1B2R1/4PP1P/pQ3K2/P6P/8 w - - 7 54'), LczeroBoard('8/3kp3/4p2p/6rP/1PK5/2P4R/1P6/8 w - - 3 71'), LczeroBoard('r1bq1rk1/1p2bppp/p1n1pn2/8/3P1B2/1BN2N2/PP3PPP/R2QR1K1 b - - 4 11'), LczeroBoard('r2r4/5qbk/bppp3p/p3p1p1/2PnPn2/PPNNBP2/1R1Q2PP/4RB1K b - - 1 30'), LczeroBoard('r2qk2r/pbppbpp1/1pn1n3/3N3p/3p2Q1/3B1N2/PPP2PPP/R1B1R1K1 w kq - 0 12'), LczeroBoard('4k3/7p/3p2pP/2nPp1P1/qp2P3/8/PQ3NK1/8 w - - 2 71'), LczeroBoard('5kb1/Q3b3/2B3p1/4Pp1p/1p1N1P1P/1p1P2PK/1P6/5q2 w - - 7 117'), LczeroBoard('4r3/8/2k5/2PRp2p/4P1pP/6P1/2K5/8 w - - 15 61'), LczeroBoard('2rqr1k1/3bp1bp/1p1p1pp1/p1nN1P2/2PNP3/1P1nB1PP/P2Q2B1/3R1RK1 b - - 5 23'), LczeroBoard('6k1/2r3p1/5p1p/5P1P/6PK/8/1R6/8 w - - 1 339'), LczeroBoard('8/8/1p1p2p1/3Pp1k1/2P4p/q6P/6P1/5QK1 w - - 0 65'), LczeroBoard('3R4/6k1/8/5r1p/7P/6PK/8/8 w - - 28 91'), LczeroBoard('r1bq1rk1/1p2b1pp/2np4/1Nn1pp2/p1P1P3/2N1BP2/PP2B1PP/R2QK2R b KQ - 1 14'), LczeroBoard('2b3r1/5k2/p2bpp2/PpNp1q1p/1PnP1P2/3N2PP/4R1Q1/4B1K1 w - - 1 48'), LczeroBoard('8/5pk1/p3p1p1/P5p1/Q5P1/5n2/5P2/1q3BK1 w - - 16 52'), LczeroBoard('2q5/p2r4/6k1/1QpN2p1/2PbPpP1/1R3P1p/7P/4K3 w - - 71 85'), LczeroBoard('8/1R6/6k1/6P1/8/8/1p4K1/1r6 b - - 0 99'), LczeroBoard('r5k1/1pp2p2/p2p2pQ/7b/2P1R3/1P4P1/P4PB1/q5K1 w - - 2 23'), LczeroBoard('8/8/8/3R2k1/8/5K2/8/r7 b - - 6 60'), LczeroBoard('8/3k4/2p4r/1p1p1p1P/5P2/P1R1PK2/4r3/7R b - - 52 147'), LczeroBoard('r1b1rn2/pp3k1p/2pp2p1/5p2/3B1P2/3B3R/PPP3PP/3R2K1 w - - 0 22'), LczeroBoard('r1bq1rk1/1p2bppp/1pnp4/2p1p3/P1B1P3/2PP4/1P3PPP/R1BQKN1R b KQ - 0 11'), LczeroBoard('r7/1R3pk1/1N4p1/5rPp/5P2/8/5R1K/8 b - - 0 53'), LczeroBoard('4r3/8/5Qpk/3p4/2p5/1p4qp/4R3/1B5K w - - 8 120'), LczeroBoard('2kr3r/2p1bppp/p1np4/1p3q2/5B2/1PP2N2/1PPQ1PPP/1K1R3R w - - 1 15'), LczeroBoard('8/5pp1/8/5Pk1/Ppp1R1P1/3q4/PP5K/4R3 w - - 0 43')], {})

Create a Concept Dataset#

[5]:
from lczerolens.concepts.threat import HasThreat
from lczerolens.lenses import ActivationLens

concept = HasThreat(piece="Q", relative=True)
[30]:
import numpy as np


def get_activations_and_labels(model, module_name, dataloader, concept, n_batches=3):
    lens = ActivationLens(pattern=module_name)
    activations_list = []
    labels_list = []
    i = 0
    for result in lens.analyse_batched(model, dataloader):
        activations_list.append(result[module_name + "_output"].detach().cpu().numpy())
        i += 1
        if i > n_batches:
            break
    i = 0
    for boards, _ in dataloader:
        for board in boards:
            labels_list.append(concept.compute_label(board))
        i += 1
        if i > n_batches:
            break
    return np.concatenate(activations_list), np.array(labels_list)
[49]:
train_activations, train_labels = get_activations_and_labels(
    strong_model, "block18/conv2/relu", train_dataloader, concept, n_batches=10
)
train_activations.shape
[49]:
(1100, 256, 8, 8)
[57]:
X_train = train_activations.reshape(train_activations.shape[0], -1)
Y_train = train_labels
(X_train.shape, Y_train.shape)
[57]:
((1100, 16384), (1100,))

Train a Linear Probe#

[55]:
from sklearn.linear_model import LogisticRegression

probe = LogisticRegression(max_iter=10000)
probe.fit(X_train, Y_train)
[55]:
LogisticRegression(max_iter=10000)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
[56]:
print(HasThreat.compute_metrics(probe.predict(X_train), Y_train))
{'accuracy': 0.9563636363636364, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
/Users/xmaster/Work/lczerolens/.venv/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
[35]:
test_dataloader = DataLoader(board_datasetdict["test"], batch_size=100, shuffle=True, collate_fn=Game.board_collate_fn)
test_activations, test_labels = get_activations_and_labels(
    strong_model, "block18/conv2/relu", test_dataloader, concept, n_batches=3
)
[53]:
X_test = test_activations.reshape(test_activations.shape[0], -1)
Y_test = test_labels
print(HasThreat.compute_metrics(probe.predict(X_test), Y_test))
{'accuracy': 0.965, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
/Users/xmaster/Work/lczerolens/.venv/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
[38]:
Y_test.sum()
[38]:
np.int64(14)

Evaluate the Probe#

[ ]:
# Generic eval using ProbingLens