Visualise Heatmaps#

Open In Colab

Setup#

[ ]:
import importlib.util

# Set to True to install lczerolens from the `main` branch (instead of the
# released package) when running on Colab.
DEV = True

# `find_spec("google.colab")` first imports the parent package `google`; when
# that package is absent it raises ModuleNotFoundError instead of returning
# None, so treat that case as "not on Colab" rather than crashing the setup.
try:
    IN_COLAB = importlib.util.find_spec("google.colab") is not None
except ModuleNotFoundError:
    IN_COLAB = False

if IN_COLAB:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
    %pip install -q lczerolens[hf]
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    %pip install -q ./lczerolens --extra hf

Visualise Attention#

[3]:
from lczerolens import LczeroModel
from lczerolens import LczeroBoard

from tdhook.latent.activation_caching import ActivationCaching


transformer_model = LczeroModel.from_hf("lczerolens/evidence-of-learned-lookahead")
board = LczeroBoard(fen="1rb1rbk1/2qn1p1p/p2p2p1/1ppPp2n/PP2P3/2P1BN1P/R1BN1PP1/3QR1K1 w - - 0 22")

# Cache the outputs of every module whose key matches `encoder<N>/mha/QK/softmax`.
# The pattern must be a raw string: "\d" inside a plain string literal is an
# invalid escape sequence (SyntaxWarning since Python 3.12).
hooking_context = ActivationCaching(key_pattern=r"encoder\d+/mha/QK/softmax", relative=True).prepare(transformer_model)

# Run a single forward pass with the hooks attached; the captured tensors stay
# available on `hooking_context.cache` after the context exits.
with hooking_context as hooked_model:
    hooked_model(board.to_input_tensor())
hooking_context.cache
/Users/xmaster/Work/Chess/lczerolens/.venv/lib/python3.12/site-packages/onnx2torch/node_converters/slice.py:63: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_variable_indexing.cpp:312.)
  x = x[pos_axes_slices]
[3]:
TensorDict(
    fields={
        encoder0/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder1/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder10/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder11/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder12/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder13/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder14/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder2/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder3/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder4/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder5/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder6/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder7/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder8/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        encoder9/mha/QK/softmax: Tensor(shape=torch.Size([1, 24, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
[4]:
import chess
import IPython.display

# Which attention map to show: batch element, encoder layer and head.
batch_index, layer, head = 0, 9, 5
# Row of the selected attention map to draw, indexed by the square F3.
piece = chess.F3

layer_key = f"encoder{layer}/mha/QK/softmax"
attention_weights = hooking_context.cache[layer_key][batch_index, head]

square_row = attention_weights[piece].detach()
svg_board, svg_colorbar = board.render_heatmap(square_row)
IPython.display.display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-09-04T17:02:54.344695 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

Visualise Gradients#

[5]:
from tdhook.attribution.saliency import Saliency
from tensordict import TensorDict

cnn_model = LczeroModel.from_hf("lczerolens/maia-1900")


def init_targets(td, _):
    """Build the attribution target: entry 0 of the last axis of the ``wdl`` output."""
    first_wdl_entry = td["wdl"][..., 0]
    return TensorDict(out=first_wdl_entry, batch_size=td.batch_size)


# Encode the position as a batch of one, stored under the "board" key.
encoded_boards = cnn_model.prepare_boards(board)
td = TensorDict(board=encoded_boards, batch_size=1)

# Run the model once with saliency hooks; the attributions land in `output`.
saliency_context = Saliency(init_attr_targets=init_targets)
with saliency_context.prepare(cnn_model) as hooked_model:
    output = hooked_model(td)
[6]:
batch_index = 0
plane = 1  # N

# Attribution of a single input plane, laid out as 64 squares.
board_attribution = output.get(("attr", "board"))
plane_attribution = board_attribution[batch_index, plane].view(64).detach()

svg_board, svg_colorbar = board.render_heatmap(plane_attribution, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-09-04T17:02:54.588617 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/
[7]:
# Average the attribution over the first 12 input planes (presumably the piece
# planes — confirm against the input encoding), giving one value per square.
board_attribution = output.get(("attr", "board"))
gap_input_grad = board_attribution[:, :12].mean(dim=1)

square_attribution = gap_input_grad[batch_index].view(64).detach()
svg_board, svg_colorbar = board.render_heatmap(square_attribution, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-09-04T17:02:54.629408 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

GradCAM#

Looking at the value#

[8]:
from tdhook.attribution.grad_cam import GradCAM, DimsConfig

# `cnn_model` (maia-1900) is already loaded in the gradient section; reuse it
# instead of downloading and converting the same checkpoint a second time.


def init_value_targets(td, _):
    """Attribution target: entry 0 of the last axis of the model's ``wdl`` output.

    Defined under its own name (rather than redefining ``init_targets`` a second
    time) so this cell keeps targeting the value head even when re-run after the
    policy section has rebound ``init_targets``.
    """
    return TensorDict(out=td["wdl"][..., 0], batch_size=td.batch_size)


# Fresh single-board batch for this run.
td = TensorDict(board=cnn_model.prepare_boards(board), batch_size=1)

# Layers to explain. Dims 2-3 are kept as the spatial map and dim 1 is pooled;
# the recorded output shows (1, C, 8, 8) activations reduced to (1, 8, 8) maps.
modules_to_attribute = {
    ":value/conv:": DimsConfig(
        feature_dims=(2, 3),
        pooling_dims=(1,),
    ),
    ":block4/conv2/relu:": DimsConfig(
        feature_dims=(2, 3),
        pooling_dims=(1,),
    ),
}

gradcam_context = GradCAM(modules_to_attribute, init_attr_targets=init_value_targets)
with gradcam_context.prepare(cnn_model) as hooked_model:
    output = hooked_model(td)
output
[8]:
TensorDict(
    fields={
        _cache_in: TensorDict(
            fields={
                :block4/conv2/relu:: Tensor(shape=torch.Size([1, 64, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                :value/conv:: Tensor(shape=torch.Size([1, 32, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False),
        attr: TensorDict(
            fields={
                :block4/conv2/relu:: Tensor(shape=torch.Size([1, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                :value/conv:: Tensor(shape=torch.Size([1, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False),
        board: Tensor(shape=torch.Size([1, 112, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)
[9]:
# GradCAM map of the value head's convolution, flattened to 64 squares.
value_cam = output.get(("attr", ":value/conv:"))
value_squares = value_cam.flatten().detach()

svg_board, svg_colorbar = board.render_heatmap(value_squares, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-09-04T17:02:54.941315 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/
[10]:
# GradCAM map of the `block4/conv2/relu` activations, flattened to 64 squares.
block4_cam = output.get(("attr", ":block4/conv2/relu:"))
block4_squares = block4_cam.flatten().detach()

svg_board, svg_colorbar = board.render_heatmap(block4_squares, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-09-04T17:02:54.982857 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

Looking at the policy#

[11]:
import torch
from lczerolens.sampling import PolicySampler

policy_sampler = PolicySampler(model=cnn_model)

# Only one board is passed, so the first item yielded covers it; the third
# element of the tuple is not needed here.
first_item = next(iter(policy_sampler.get_utilities([board])))
utilities, legal_indices, _ = first_item

# The three highest utilities and the policy indices they correspond to.
topk_indices = utilities.topk(k=3).indices
utilities[topk_indices], legal_indices[topk_indices]
[11]:
(tensor([10.5965, 10.5532, 10.3011]), tensor([590, 378, 668]))
[12]:
topk_index = 0  # rank among the top-k moves (0 = highest utility)

# Copy the analysed position from `board` instead of repeating its FEN by hand,
# so the two cannot drift apart; `board` itself stays unmodified for later cells.
demo_board = LczeroBoard(fen=board.fen())
move = demo_board.decode_move(legal_indices[topk_indices[topk_index]])
demo_board.push(move)
display(demo_board)
../../_images/notebooks_features_visualise-heatmaps_18_0.svg
[13]:
policy_index = legal_indices[topk_indices[topk_index]]


def init_targets(td, _):
    """Attribution target: the ``policy`` output entry at ``policy_index`` (the selected move)."""
    selected_move_output = td["policy"][..., policy_index]
    return TensorDict(out=selected_move_output, batch_size=td.batch_size)


td = TensorDict(board=cnn_model.prepare_boards(board), batch_size=1)

# Single layer to explain: keep dims 2-3 as the spatial map, pool over dim 1.
modules_to_attribute = {
    ":policy/conv2:": DimsConfig(feature_dims=(2, 3), pooling_dims=(1,)),
}

saliency_context = GradCAM(modules_to_attribute, init_attr_targets=init_targets)
with saliency_context.prepare(cnn_model) as hooked_model:
    output = hooked_model(td)

policy_cam = output.get(("attr", ":policy/conv2:")).view(64).detach()
svg_board, svg_colorbar = board.render_heatmap(policy_cam, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-09-04T17:02:55.066276 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/