Run Models on GPU

Open In Colab

Setup

[1]:
# Runtime environment; selects what the install cell below does:
#   "colab"     -> install the released lczerolens package
#   "colab-dev" -> clone the main branch of lczerolens and install it from source
#   "local"     -> install nothing (lczerolens is expected to be available already)
MODE = "local"
[2]:
if MODE == "colab":
    !pip install -q lczerolens
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    !pip install -q ./lczerolens
[3]:
!gdown 1cxC8_8vw7akfPyc9cZxwaAbLG2Zl4XiT -O lc0-10-4238.onnx
[4]:
import torch

# Fail fast: later cells move the model onto "cuda", which cannot work on a
# CPU-only runtime.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a GPU")

Load the Model and the Dataset

[5]:
from datasets import load_dataset
from lczerolens import LczeroModel

# Convert the downloaded ONNX network into a torch module, then place it on the GPU.
model = LczeroModel.from_path("lc0-10-4238.onnx")
model = model.to("cuda")

# Board positions (starting FEN + list of moves) from the Hugging Face Hub.
dataset = load_dataset("lczero-planning/boards")
dataset
[5]:
DatasetDict({
    train: Dataset({
        features: ['gameid', 'moves', 'fen'],
        num_rows: 2231423
    })
    test: Dataset({
        features: ['gameid', 'moves', 'fen'],
        num_rows: 557856
    })
})
[6]:
# Show the module tree converted from the ONNX file. The names in parentheses are
# module paths, e.g. "block9/conv2/relu", the one later passed to ActivationLens.
model
[6]:
GraphModule(
  (inputconv): Conv2d(112, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (inputconv/relu): ReLU()
  (block0/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block0/conv1/relu): ReLU()
  (block0/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block0/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (initializers): Module()
  (block0/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block0/conv2/se/matmul1): OnnxMatMul()
  (block0/conv2/se/add1): OnnxBinaryMathOperation()
  (block0/conv2/se/relu): ReLU()
  (block0/conv2/se/matmul2): OnnxMatMul()
  (block0/conv2/se/add2): OnnxBinaryMathOperation()
  (block0/conv2/se/reshape): OnnxReshape()
  (block0/conv2/se/split): OnnxSplit13()
  (block0/conv2/se/sigmoid): Sigmoid()
  (block0/conv2/se/mul): OnnxBinaryMathOperation()
  (block0/conv2/se/add3): OnnxBinaryMathOperation()
  (block0/conv2/mixin): OnnxBinaryMathOperation()
  (block0/conv2/relu): ReLU()
  (block1/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block1/conv1/relu): ReLU()
  (block1/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block1/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block1/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block1/conv2/se/matmul1): OnnxMatMul()
  (block1/conv2/se/add1): OnnxBinaryMathOperation()
  (block1/conv2/se/relu): ReLU()
  (block1/conv2/se/matmul2): OnnxMatMul()
  (block1/conv2/se/add2): OnnxBinaryMathOperation()
  (block1/conv2/se/reshape): OnnxReshape()
  (block1/conv2/se/split): OnnxSplit13()
  (block1/conv2/se/sigmoid): Sigmoid()
  (block1/conv2/se/mul): OnnxBinaryMathOperation()
  (block1/conv2/se/add3): OnnxBinaryMathOperation()
  (block1/conv2/mixin): OnnxBinaryMathOperation()
  (block1/conv2/relu): ReLU()
  (block2/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block2/conv1/relu): ReLU()
  (block2/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block2/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block2/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block2/conv2/se/matmul1): OnnxMatMul()
  (block2/conv2/se/add1): OnnxBinaryMathOperation()
  (block2/conv2/se/relu): ReLU()
  (block2/conv2/se/matmul2): OnnxMatMul()
  (block2/conv2/se/add2): OnnxBinaryMathOperation()
  (block2/conv2/se/reshape): OnnxReshape()
  (block2/conv2/se/split): OnnxSplit13()
  (block2/conv2/se/sigmoid): Sigmoid()
  (block2/conv2/se/mul): OnnxBinaryMathOperation()
  (block2/conv2/se/add3): OnnxBinaryMathOperation()
  (block2/conv2/mixin): OnnxBinaryMathOperation()
  (block2/conv2/relu): ReLU()
  (block3/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block3/conv1/relu): ReLU()
  (block3/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block3/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block3/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block3/conv2/se/matmul1): OnnxMatMul()
  (block3/conv2/se/add1): OnnxBinaryMathOperation()
  (block3/conv2/se/relu): ReLU()
  (block3/conv2/se/matmul2): OnnxMatMul()
  (block3/conv2/se/add2): OnnxBinaryMathOperation()
  (block3/conv2/se/reshape): OnnxReshape()
  (block3/conv2/se/split): OnnxSplit13()
  (block3/conv2/se/sigmoid): Sigmoid()
  (block3/conv2/se/mul): OnnxBinaryMathOperation()
  (block3/conv2/se/add3): OnnxBinaryMathOperation()
  (block3/conv2/mixin): OnnxBinaryMathOperation()
  (block3/conv2/relu): ReLU()
  (block4/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block4/conv1/relu): ReLU()
  (block4/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block4/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block4/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block4/conv2/se/matmul1): OnnxMatMul()
  (block4/conv2/se/add1): OnnxBinaryMathOperation()
  (block4/conv2/se/relu): ReLU()
  (block4/conv2/se/matmul2): OnnxMatMul()
  (block4/conv2/se/add2): OnnxBinaryMathOperation()
  (block4/conv2/se/reshape): OnnxReshape()
  (block4/conv2/se/split): OnnxSplit13()
  (block4/conv2/se/sigmoid): Sigmoid()
  (block4/conv2/se/mul): OnnxBinaryMathOperation()
  (block4/conv2/se/add3): OnnxBinaryMathOperation()
  (block4/conv2/mixin): OnnxBinaryMathOperation()
  (block4/conv2/relu): ReLU()
  (block5/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block5/conv1/relu): ReLU()
  (block5/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block5/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block5/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block5/conv2/se/matmul1): OnnxMatMul()
  (block5/conv2/se/add1): OnnxBinaryMathOperation()
  (block5/conv2/se/relu): ReLU()
  (block5/conv2/se/matmul2): OnnxMatMul()
  (block5/conv2/se/add2): OnnxBinaryMathOperation()
  (block5/conv2/se/reshape): OnnxReshape()
  (block5/conv2/se/split): OnnxSplit13()
  (block5/conv2/se/sigmoid): Sigmoid()
  (block5/conv2/se/mul): OnnxBinaryMathOperation()
  (block5/conv2/se/add3): OnnxBinaryMathOperation()
  (block5/conv2/mixin): OnnxBinaryMathOperation()
  (block5/conv2/relu): ReLU()
  (block6/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block6/conv1/relu): ReLU()
  (block6/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block6/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block6/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block6/conv2/se/matmul1): OnnxMatMul()
  (block6/conv2/se/add1): OnnxBinaryMathOperation()
  (block6/conv2/se/relu): ReLU()
  (block6/conv2/se/matmul2): OnnxMatMul()
  (block6/conv2/se/add2): OnnxBinaryMathOperation()
  (block6/conv2/se/reshape): OnnxReshape()
  (block6/conv2/se/split): OnnxSplit13()
  (block6/conv2/se/sigmoid): Sigmoid()
  (block6/conv2/se/mul): OnnxBinaryMathOperation()
  (block6/conv2/se/add3): OnnxBinaryMathOperation()
  (block6/conv2/mixin): OnnxBinaryMathOperation()
  (block6/conv2/relu): ReLU()
  (block7/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block7/conv1/relu): ReLU()
  (block7/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block7/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block7/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block7/conv2/se/matmul1): OnnxMatMul()
  (block7/conv2/se/add1): OnnxBinaryMathOperation()
  (block7/conv2/se/relu): ReLU()
  (block7/conv2/se/matmul2): OnnxMatMul()
  (block7/conv2/se/add2): OnnxBinaryMathOperation()
  (block7/conv2/se/reshape): OnnxReshape()
  (block7/conv2/se/split): OnnxSplit13()
  (block7/conv2/se/sigmoid): Sigmoid()
  (block7/conv2/se/mul): OnnxBinaryMathOperation()
  (block7/conv2/se/add3): OnnxBinaryMathOperation()
  (block7/conv2/mixin): OnnxBinaryMathOperation()
  (block7/conv2/relu): ReLU()
  (block8/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block8/conv1/relu): ReLU()
  (block8/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block8/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block8/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block8/conv2/se/matmul1): OnnxMatMul()
  (block8/conv2/se/add1): OnnxBinaryMathOperation()
  (block8/conv2/se/relu): ReLU()
  (block8/conv2/se/matmul2): OnnxMatMul()
  (block8/conv2/se/add2): OnnxBinaryMathOperation()
  (block8/conv2/se/reshape): OnnxReshape()
  (block8/conv2/se/split): OnnxSplit13()
  (block8/conv2/se/sigmoid): Sigmoid()
  (block8/conv2/se/mul): OnnxBinaryMathOperation()
  (block8/conv2/se/add3): OnnxBinaryMathOperation()
  (block8/conv2/mixin): OnnxBinaryMathOperation()
  (block8/conv2/relu): ReLU()
  (block9/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block9/conv1/relu): ReLU()
  (block9/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block9/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block9/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block9/conv2/se/matmul1): OnnxMatMul()
  (block9/conv2/se/add1): OnnxBinaryMathOperation()
  (block9/conv2/se/relu): ReLU()
  (block9/conv2/se/matmul2): OnnxMatMul()
  (block9/conv2/se/add2): OnnxBinaryMathOperation()
  (block9/conv2/se/reshape): OnnxReshape()
  (block9/conv2/se/split): OnnxSplit13()
  (block9/conv2/se/sigmoid): Sigmoid()
  (block9/conv2/se/mul): OnnxBinaryMathOperation()
  (block9/conv2/se/add3): OnnxBinaryMathOperation()
  (block9/conv2/mixin): OnnxBinaryMathOperation()
  (block9/conv2/relu): ReLU()
  (policy/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (policy/conv1/relu): ReLU()
  (policy/conv2): Conv2d(128, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (policy/flatten): OnnxReshape()
  (output/policy): OnnxGather()
  (value/conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
  (value/conv/relu): ReLU()
  (value/reshape): OnnxReshape()
  (value/dense1/matmul): OnnxMatMul()
  (value/dense1/add): OnnxBinaryMathOperation()
  (value/dense1/relu): ReLU()
  (value/dense2/matmul): OnnxMatMul()
  (value/dense2/add): OnnxBinaryMathOperation()
  (output/wdl): Softmax(dim=1)
  (mlh/conv): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))
  (mlh/conv/relu): ReLU()
  (mlh/reshape): OnnxReshape()
  (mlh/dense1/matmul): OnnxMatMul()
  (mlh/dense1/add): OnnxBinaryMathOperation()
  (mlh/dense1/relu): ReLU()
  (mlh/dense2/matmul): OnnxMatMul()
  (mlh/dense2/add): OnnxBinaryMathOperation()
  (mlh/dense2/relu): ReLU()
  (output/mlh): OnnxCopyIdentity()
)

Set Up the Activation Buffer

[13]:
import chess
import einops


def collate_fn(batch):
    """Turn a batch of dataset rows into ``chess.Board`` objects.

    Each row's board starts from its ``"fen"`` position, and every UCI string in
    its ``"moves"`` list is then pushed onto it in order.

    Args:
        batch: iterable of rows, each a mapping with ``"fen"`` and ``"moves"`` keys.

    Returns:
        A list with one replayed board per row, in batch order.
    """

    def _replay(row):
        # Read both fields up front so a missing key fails before any board work.
        start_fen, uci_moves = row["fen"], row["moves"]
        position = chess.Board(start_fen)
        for uci in uci_moves:
            position.push(chess.Move.from_uci(uci))
        return position

    return [_replay(row) for row in batch]


def _compute_fn(batch, model, lens):
    boards = batch
    storage = lens.analyse(*boards, model=model)[0]
    if len(storage.keys()) != 1:
        raise NotImplementedError
    acts = next(iter(storage.values()))
    return einops.rearrange(acts, "b c h w -> (b h w) c")
[17]:
from lczerolens.lenses import ActivationLens, ActivationBuffer

# Module whose output activations are collected (a path as listed in the model repr).
MODULE_NAME = "block9/conv2/relu"
LENS = ActivationLens(MODULE_NAME)

# Boards per forward pass. Each board contributes h*w activation rows (64 for
# an 8x8 board) after the "(b h w) c" rearrange.
COMPUTE_BATCH_SIZE = 1_000
# Compute batches held in the buffer. Presumably 15 x 1_000 x 64 = 960_000
# stored rows, which matches the stored shape printed below.
N_BATCHES_IN_BUFFER = 15
# Activation rows yielded per iteration over the buffer.
TRAIN_BATCH_SIZE = 10_000


def compute_fn(batch, model):
    """Activation callback handed to ActivationBuffer: ``_compute_fn`` using the notebook's ``LENS``."""
    lens = LENS
    return _compute_fn(batch, model, lens=lens)
[18]:
def _make_activation_buffer(split):
    """Build an ActivationBuffer over ``dataset[split]`` with the notebook's shared settings.

    The train and validation buffers differ only in the dataset split, so the
    constructor arguments live in one place and cannot drift apart.
    """
    return ActivationBuffer(
        model,
        dataset[split],
        compute_fn,
        N_BATCHES_IN_BUFFER,
        COMPUTE_BATCH_SIZE,
        TRAIN_BATCH_SIZE,
        dataloader_kwargs={"collate_fn": collate_fn},
    )


train_buffer = _make_activation_buffer("train")
# The dataset only has "train" and "test" splits; "test" serves as validation.
val_buffer = _make_activation_buffer("test")
[19]:
acts = next(iter(train_buffer))
print("Out acts: ", acts.shape)

# NOTE(review): `_buffer` is private ActivationBuffer state, presumably a sequence
# of 2-D activation chunks; confirm against lczerolens. Summing the chunk row counts
# gives the same shape as `torch.cat(train_buffer._buffer, dim=0).shape`. It avoids
# allocating a concatenated copy of the whole buffer (~960k x 128 values here) just
# to print it.
stored_chunks = train_buffer._buffer
if stored_chunks:
    stored_rows = sum(chunk.shape[0] for chunk in stored_chunks)
    stored_shape = torch.Size([stored_rows, *stored_chunks[0].shape[1:]])
else:
    # torch.cat would raise on an empty buffer; report an empty shape instead.
    stored_shape = torch.Size([0])
print("Stored acts: ", stored_shape)
Out acts:  torch.Size([10000, 128])
Stored acts:  torch.Size([960000, 128])

Train an SAE

Evaluate an SAE