Visualise Heatmaps

Open In Colab

Setup

[1]:
# Selects how the next cell prepares the environment:
#   "colab"     - pip-installs the released lczerolens package
#   "colab-dev" - clones the main branch of the lczerolens repository and installs that
#   "local"     - installs nothing
MODE = "local"  # "colab" | "colab-dev" | "local"
[2]:
if MODE == "colab":
    !pip install -q lczerolens
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    !pip install -q ./lczerolens
[3]:
!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx
!gdown 1PB097ZKd_zTaPHxLK29WKUWmv6KcZ15T -O lc0.onnx
Downloading...
From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd
To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/features/lc0-19-4508.onnx
100%|██████████████████████████████████████| 97.1M/97.1M [00:01<00:00, 49.4MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1PB097ZKd_zTaPHxLK29WKUWmv6KcZ15T
From (redirected): https://drive.google.com/uc?id=1PB097ZKd_zTaPHxLK29WKUWmv6KcZ15T&confirm=t&uuid=d1bc81a6-4ce3-4d23-ae2b-5c11d0e44993
To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/features/lc0.onnx
100%|████████████████████████████████████████| 379M/379M [00:07<00:00, 52.8MB/s]

Visualise Attention

[4]:
from lczerolens import LczeroModel

# Build the model from the lc0.onnx file in the working directory; the bare
# name on the last line makes the notebook display the model's repr.
transformer_weights = "lc0.onnx"
transformer_model = LczeroModel.from_path(transformer_weights)
transformer_model
/Users/xmaster/Work/lczerolens/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[4]:
GraphModule(
  (attn_body/transpose): OnnxTranspose()
  (initializers): Module()
  (attn_body/reshape): OnnxReshape()
  (attn_body/shape): OnnxShape()
  (attn_body/batch): OnnxSlice()
  (attn_body/pos_encoding_shape): OnnxConcat()
  (attn_body/expand): OnnxExpand()
  (attn_body/padded_input): OnnxConcat()
  (attn_body/reshape2): OnnxReshape()
  (attn_body/matmul): OnnxMatMul()
  (attn_body/add): OnnxBinaryMathOperation()
  (attn_body/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (attn_body/mish/tanh): OnnxFunction()
  (attn_body/mish): OnnxBinaryMathOperation()
  (attn_body/ma_gating/rehape1): OnnxReshape()
  (ip_mul_gate): OnnxBinaryMathOperation()
  (ip_add_gate): OnnxBinaryMathOperation()
  (attn_body/ma_gating/rehape2): OnnxReshape()
  (encoder0/mha/Q/w): OnnxMatMul()
  (encoder0/mha/Q/b): OnnxBinaryMathOperation()
  (encoder0/mha/Q/reshape): OnnxReshape()
  (encoder0/mha/Q/transpose): OnnxTranspose()
  (encoder0/mha/K/w): OnnxMatMul()
  (encoder0/mha/K/b): OnnxBinaryMathOperation()
  (encoder0/mha/K/reshape): OnnxReshape()
  (encoder0/mha/K/transpose): OnnxTranspose()
  (encoder0/mha/V/w): OnnxMatMul()
  (encoder0/mha/V/b): OnnxBinaryMathOperation()
  (encoder0/mha/V/reshape): OnnxReshape()
  (encoder0/mha/V/transpose): OnnxTranspose()
  (encoder0/mha/QK/matmul): OnnxMatMul()
  (encoder0/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder0/smolgen/compress): OnnxMatMul()
  (encoder0/smolgen/compress/reshape): OnnxReshape()
  (encoder0/smolgen/dense1/w): OnnxMatMul()
  (encoder0/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder0/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder0/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder0/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder0/smolgen/dense2/w): OnnxMatMul()
  (encoder0/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder0/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder0/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder0/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder0/smolgen/gen_from/reshape): OnnxReshape()
  (encoder0/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder0/smolgen/out/reshape): OnnxReshape()
  (encoder0/smolgen_weights): OnnxBinaryMathOperation()
  (encoder0/mha/QK/softmax): Softmax(dim=3)
  (encoder0/mha/QKV/matmul): OnnxMatMul()
  (encoder0/mha/out/transpose): OnnxTranspose()
  (encoder0/mha/out/reshape): OnnxReshape()
  (encoder0/mha/out/dense/w): OnnxMatMul()
  (encoder0/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder0/alpha*input): OnnxBinaryMathOperation()
  (encoder0/mha/out/skip): OnnxBinaryMathOperation()
  (encoder0/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder0/ffn/dense1/w): OnnxMatMul()
  (encoder0/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder0/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder0/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder0/ffn/dense2/w): OnnxMatMul()
  (encoder0/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder0/alpha*out1): OnnxBinaryMathOperation()
  (encoder0/ffn/skip): OnnxBinaryMathOperation()
  (encoder0/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder1/mha/Q/w): OnnxMatMul()
  (encoder1/mha/Q/b): OnnxBinaryMathOperation()
  (encoder1/mha/Q/reshape): OnnxReshape()
  (encoder1/mha/Q/transpose): OnnxTranspose()
  (encoder1/mha/K/w): OnnxMatMul()
  (encoder1/mha/K/b): OnnxBinaryMathOperation()
  (encoder1/mha/K/reshape): OnnxReshape()
  (encoder1/mha/K/transpose): OnnxTranspose()
  (encoder1/mha/V/w): OnnxMatMul()
  (encoder1/mha/V/b): OnnxBinaryMathOperation()
  (encoder1/mha/V/reshape): OnnxReshape()
  (encoder1/mha/V/transpose): OnnxTranspose()
  (encoder1/mha/QK/matmul): OnnxMatMul()
  (encoder1/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder1/smolgen/compress): OnnxMatMul()
  (encoder1/smolgen/compress/reshape): OnnxReshape()
  (encoder1/smolgen/dense1/w): OnnxMatMul()
  (encoder1/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder1/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder1/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder1/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder1/smolgen/dense2/w): OnnxMatMul()
  (encoder1/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder1/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder1/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder1/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder1/smolgen/gen_from/reshape): OnnxReshape()
  (encoder1/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder1/smolgen/out/reshape): OnnxReshape()
  (encoder1/smolgen_weights): OnnxBinaryMathOperation()
  (encoder1/mha/QK/softmax): Softmax(dim=3)
  (encoder1/mha/QKV/matmul): OnnxMatMul()
  (encoder1/mha/out/transpose): OnnxTranspose()
  (encoder1/mha/out/reshape): OnnxReshape()
  (encoder1/mha/out/dense/w): OnnxMatMul()
  (encoder1/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder1/alpha*input): OnnxBinaryMathOperation()
  (encoder1/mha/out/skip): OnnxBinaryMathOperation()
  (encoder1/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder1/ffn/dense1/w): OnnxMatMul()
  (encoder1/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder1/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder1/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder1/ffn/dense2/w): OnnxMatMul()
  (encoder1/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder1/alpha*out1): OnnxBinaryMathOperation()
  (encoder1/ffn/skip): OnnxBinaryMathOperation()
  (encoder1/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder2/mha/Q/w): OnnxMatMul()
  (encoder2/mha/Q/b): OnnxBinaryMathOperation()
  (encoder2/mha/Q/reshape): OnnxReshape()
  (encoder2/mha/Q/transpose): OnnxTranspose()
  (encoder2/mha/K/w): OnnxMatMul()
  (encoder2/mha/K/b): OnnxBinaryMathOperation()
  (encoder2/mha/K/reshape): OnnxReshape()
  (encoder2/mha/K/transpose): OnnxTranspose()
  (encoder2/mha/V/w): OnnxMatMul()
  (encoder2/mha/V/b): OnnxBinaryMathOperation()
  (encoder2/mha/V/reshape): OnnxReshape()
  (encoder2/mha/V/transpose): OnnxTranspose()
  (encoder2/mha/QK/matmul): OnnxMatMul()
  (encoder2/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder2/smolgen/compress): OnnxMatMul()
  (encoder2/smolgen/compress/reshape): OnnxReshape()
  (encoder2/smolgen/dense1/w): OnnxMatMul()
  (encoder2/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder2/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder2/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder2/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder2/smolgen/dense2/w): OnnxMatMul()
  (encoder2/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder2/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder2/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder2/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder2/smolgen/gen_from/reshape): OnnxReshape()
  (encoder2/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder2/smolgen/out/reshape): OnnxReshape()
  (encoder2/smolgen_weights): OnnxBinaryMathOperation()
  (encoder2/mha/QK/softmax): Softmax(dim=3)
  (encoder2/mha/QKV/matmul): OnnxMatMul()
  (encoder2/mha/out/transpose): OnnxTranspose()
  (encoder2/mha/out/reshape): OnnxReshape()
  (encoder2/mha/out/dense/w): OnnxMatMul()
  (encoder2/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder2/alpha*input): OnnxBinaryMathOperation()
  (encoder2/mha/out/skip): OnnxBinaryMathOperation()
  (encoder2/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder2/ffn/dense1/w): OnnxMatMul()
  (encoder2/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder2/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder2/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder2/ffn/dense2/w): OnnxMatMul()
  (encoder2/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder2/alpha*out1): OnnxBinaryMathOperation()
  (encoder2/ffn/skip): OnnxBinaryMathOperation()
  (encoder2/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder3/mha/Q/w): OnnxMatMul()
  (encoder3/mha/Q/b): OnnxBinaryMathOperation()
  (encoder3/mha/Q/reshape): OnnxReshape()
  (encoder3/mha/Q/transpose): OnnxTranspose()
  (encoder3/mha/K/w): OnnxMatMul()
  (encoder3/mha/K/b): OnnxBinaryMathOperation()
  (encoder3/mha/K/reshape): OnnxReshape()
  (encoder3/mha/K/transpose): OnnxTranspose()
  (encoder3/mha/V/w): OnnxMatMul()
  (encoder3/mha/V/b): OnnxBinaryMathOperation()
  (encoder3/mha/V/reshape): OnnxReshape()
  (encoder3/mha/V/transpose): OnnxTranspose()
  (encoder3/mha/QK/matmul): OnnxMatMul()
  (encoder3/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder3/smolgen/compress): OnnxMatMul()
  (encoder3/smolgen/compress/reshape): OnnxReshape()
  (encoder3/smolgen/dense1/w): OnnxMatMul()
  (encoder3/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder3/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder3/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder3/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder3/smolgen/dense2/w): OnnxMatMul()
  (encoder3/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder3/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder3/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder3/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder3/smolgen/gen_from/reshape): OnnxReshape()
  (encoder3/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder3/smolgen/out/reshape): OnnxReshape()
  (encoder3/smolgen_weights): OnnxBinaryMathOperation()
  (encoder3/mha/QK/softmax): Softmax(dim=3)
  (encoder3/mha/QKV/matmul): OnnxMatMul()
  (encoder3/mha/out/transpose): OnnxTranspose()
  (encoder3/mha/out/reshape): OnnxReshape()
  (encoder3/mha/out/dense/w): OnnxMatMul()
  (encoder3/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder3/alpha*input): OnnxBinaryMathOperation()
  (encoder3/mha/out/skip): OnnxBinaryMathOperation()
  (encoder3/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder3/ffn/dense1/w): OnnxMatMul()
  (encoder3/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder3/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder3/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder3/ffn/dense2/w): OnnxMatMul()
  (encoder3/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder3/alpha*out1): OnnxBinaryMathOperation()
  (encoder3/ffn/skip): OnnxBinaryMathOperation()
  (encoder3/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder4/mha/Q/w): OnnxMatMul()
  (encoder4/mha/Q/b): OnnxBinaryMathOperation()
  (encoder4/mha/Q/reshape): OnnxReshape()
  (encoder4/mha/Q/transpose): OnnxTranspose()
  (encoder4/mha/K/w): OnnxMatMul()
  (encoder4/mha/K/b): OnnxBinaryMathOperation()
  (encoder4/mha/K/reshape): OnnxReshape()
  (encoder4/mha/K/transpose): OnnxTranspose()
  (encoder4/mha/V/w): OnnxMatMul()
  (encoder4/mha/V/b): OnnxBinaryMathOperation()
  (encoder4/mha/V/reshape): OnnxReshape()
  (encoder4/mha/V/transpose): OnnxTranspose()
  (encoder4/mha/QK/matmul): OnnxMatMul()
  (encoder4/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder4/smolgen/compress): OnnxMatMul()
  (encoder4/smolgen/compress/reshape): OnnxReshape()
  (encoder4/smolgen/dense1/w): OnnxMatMul()
  (encoder4/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder4/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder4/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder4/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder4/smolgen/dense2/w): OnnxMatMul()
  (encoder4/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder4/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder4/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder4/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder4/smolgen/gen_from/reshape): OnnxReshape()
  (encoder4/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder4/smolgen/out/reshape): OnnxReshape()
  (encoder4/smolgen_weights): OnnxBinaryMathOperation()
  (encoder4/mha/QK/softmax): Softmax(dim=3)
  (encoder4/mha/QKV/matmul): OnnxMatMul()
  (encoder4/mha/out/transpose): OnnxTranspose()
  (encoder4/mha/out/reshape): OnnxReshape()
  (encoder4/mha/out/dense/w): OnnxMatMul()
  (encoder4/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder4/alpha*input): OnnxBinaryMathOperation()
  (encoder4/mha/out/skip): OnnxBinaryMathOperation()
  (encoder4/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder4/ffn/dense1/w): OnnxMatMul()
  (encoder4/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder4/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder4/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder4/ffn/dense2/w): OnnxMatMul()
  (encoder4/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder4/alpha*out1): OnnxBinaryMathOperation()
  (encoder4/ffn/skip): OnnxBinaryMathOperation()
  (encoder4/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder5/mha/Q/w): OnnxMatMul()
  (encoder5/mha/Q/b): OnnxBinaryMathOperation()
  (encoder5/mha/Q/reshape): OnnxReshape()
  (encoder5/mha/Q/transpose): OnnxTranspose()
  (encoder5/mha/K/w): OnnxMatMul()
  (encoder5/mha/K/b): OnnxBinaryMathOperation()
  (encoder5/mha/K/reshape): OnnxReshape()
  (encoder5/mha/K/transpose): OnnxTranspose()
  (encoder5/mha/V/w): OnnxMatMul()
  (encoder5/mha/V/b): OnnxBinaryMathOperation()
  (encoder5/mha/V/reshape): OnnxReshape()
  (encoder5/mha/V/transpose): OnnxTranspose()
  (encoder5/mha/QK/matmul): OnnxMatMul()
  (encoder5/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder5/smolgen/compress): OnnxMatMul()
  (encoder5/smolgen/compress/reshape): OnnxReshape()
  (encoder5/smolgen/dense1/w): OnnxMatMul()
  (encoder5/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder5/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder5/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder5/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder5/smolgen/dense2/w): OnnxMatMul()
  (encoder5/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder5/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder5/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder5/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder5/smolgen/gen_from/reshape): OnnxReshape()
  (encoder5/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder5/smolgen/out/reshape): OnnxReshape()
  (encoder5/smolgen_weights): OnnxBinaryMathOperation()
  (encoder5/mha/QK/softmax): Softmax(dim=3)
  (encoder5/mha/QKV/matmul): OnnxMatMul()
  (encoder5/mha/out/transpose): OnnxTranspose()
  (encoder5/mha/out/reshape): OnnxReshape()
  (encoder5/mha/out/dense/w): OnnxMatMul()
  (encoder5/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder5/alpha*input): OnnxBinaryMathOperation()
  (encoder5/mha/out/skip): OnnxBinaryMathOperation()
  (encoder5/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder5/ffn/dense1/w): OnnxMatMul()
  (encoder5/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder5/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder5/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder5/ffn/dense2/w): OnnxMatMul()
  (encoder5/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder5/alpha*out1): OnnxBinaryMathOperation()
  (encoder5/ffn/skip): OnnxBinaryMathOperation()
  (encoder5/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder6/mha/Q/w): OnnxMatMul()
  (encoder6/mha/Q/b): OnnxBinaryMathOperation()
  (encoder6/mha/Q/reshape): OnnxReshape()
  (encoder6/mha/Q/transpose): OnnxTranspose()
  (encoder6/mha/K/w): OnnxMatMul()
  (encoder6/mha/K/b): OnnxBinaryMathOperation()
  (encoder6/mha/K/reshape): OnnxReshape()
  (encoder6/mha/K/transpose): OnnxTranspose()
  (encoder6/mha/V/w): OnnxMatMul()
  (encoder6/mha/V/b): OnnxBinaryMathOperation()
  (encoder6/mha/V/reshape): OnnxReshape()
  (encoder6/mha/V/transpose): OnnxTranspose()
  (encoder6/mha/QK/matmul): OnnxMatMul()
  (encoder6/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder6/smolgen/compress): OnnxMatMul()
  (encoder6/smolgen/compress/reshape): OnnxReshape()
  (encoder6/smolgen/dense1/w): OnnxMatMul()
  (encoder6/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder6/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder6/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder6/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder6/smolgen/dense2/w): OnnxMatMul()
  (encoder6/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder6/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder6/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder6/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder6/smolgen/gen_from/reshape): OnnxReshape()
  (encoder6/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder6/smolgen/out/reshape): OnnxReshape()
  (encoder6/smolgen_weights): OnnxBinaryMathOperation()
  (encoder6/mha/QK/softmax): Softmax(dim=3)
  (encoder6/mha/QKV/matmul): OnnxMatMul()
  (encoder6/mha/out/transpose): OnnxTranspose()
  (encoder6/mha/out/reshape): OnnxReshape()
  (encoder6/mha/out/dense/w): OnnxMatMul()
  (encoder6/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder6/alpha*input): OnnxBinaryMathOperation()
  (encoder6/mha/out/skip): OnnxBinaryMathOperation()
  (encoder6/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder6/ffn/dense1/w): OnnxMatMul()
  (encoder6/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder6/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder6/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder6/ffn/dense2/w): OnnxMatMul()
  (encoder6/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder6/alpha*out1): OnnxBinaryMathOperation()
  (encoder6/ffn/skip): OnnxBinaryMathOperation()
  (encoder6/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder7/mha/Q/w): OnnxMatMul()
  (encoder7/mha/Q/b): OnnxBinaryMathOperation()
  (encoder7/mha/Q/reshape): OnnxReshape()
  (encoder7/mha/Q/transpose): OnnxTranspose()
  (encoder7/mha/K/w): OnnxMatMul()
  (encoder7/mha/K/b): OnnxBinaryMathOperation()
  (encoder7/mha/K/reshape): OnnxReshape()
  (encoder7/mha/K/transpose): OnnxTranspose()
  (encoder7/mha/V/w): OnnxMatMul()
  (encoder7/mha/V/b): OnnxBinaryMathOperation()
  (encoder7/mha/V/reshape): OnnxReshape()
  (encoder7/mha/V/transpose): OnnxTranspose()
  (encoder7/mha/QK/matmul): OnnxMatMul()
  (encoder7/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder7/smolgen/compress): OnnxMatMul()
  (encoder7/smolgen/compress/reshape): OnnxReshape()
  (encoder7/smolgen/dense1/w): OnnxMatMul()
  (encoder7/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder7/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder7/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder7/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder7/smolgen/dense2/w): OnnxMatMul()
  (encoder7/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder7/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder7/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder7/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder7/smolgen/gen_from/reshape): OnnxReshape()
  (encoder7/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder7/smolgen/out/reshape): OnnxReshape()
  (encoder7/smolgen_weights): OnnxBinaryMathOperation()
  (encoder7/mha/QK/softmax): Softmax(dim=3)
  (encoder7/mha/QKV/matmul): OnnxMatMul()
  (encoder7/mha/out/transpose): OnnxTranspose()
  (encoder7/mha/out/reshape): OnnxReshape()
  (encoder7/mha/out/dense/w): OnnxMatMul()
  (encoder7/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder7/alpha*input): OnnxBinaryMathOperation()
  (encoder7/mha/out/skip): OnnxBinaryMathOperation()
  (encoder7/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder7/ffn/dense1/w): OnnxMatMul()
  (encoder7/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder7/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder7/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder7/ffn/dense2/w): OnnxMatMul()
  (encoder7/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder7/alpha*out1): OnnxBinaryMathOperation()
  (encoder7/ffn/skip): OnnxBinaryMathOperation()
  (encoder7/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder8/mha/Q/w): OnnxMatMul()
  (encoder8/mha/Q/b): OnnxBinaryMathOperation()
  (encoder8/mha/Q/reshape): OnnxReshape()
  (encoder8/mha/Q/transpose): OnnxTranspose()
  (encoder8/mha/K/w): OnnxMatMul()
  (encoder8/mha/K/b): OnnxBinaryMathOperation()
  (encoder8/mha/K/reshape): OnnxReshape()
  (encoder8/mha/K/transpose): OnnxTranspose()
  (encoder8/mha/V/w): OnnxMatMul()
  (encoder8/mha/V/b): OnnxBinaryMathOperation()
  (encoder8/mha/V/reshape): OnnxReshape()
  (encoder8/mha/V/transpose): OnnxTranspose()
  (encoder8/mha/QK/matmul): OnnxMatMul()
  (encoder8/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder8/smolgen/compress): OnnxMatMul()
  (encoder8/smolgen/compress/reshape): OnnxReshape()
  (encoder8/smolgen/dense1/w): OnnxMatMul()
  (encoder8/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder8/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder8/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder8/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder8/smolgen/dense2/w): OnnxMatMul()
  (encoder8/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder8/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder8/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder8/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder8/smolgen/gen_from/reshape): OnnxReshape()
  (encoder8/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder8/smolgen/out/reshape): OnnxReshape()
  (encoder8/smolgen_weights): OnnxBinaryMathOperation()
  (encoder8/mha/QK/softmax): Softmax(dim=3)
  (encoder8/mha/QKV/matmul): OnnxMatMul()
  (encoder8/mha/out/transpose): OnnxTranspose()
  (encoder8/mha/out/reshape): OnnxReshape()
  (encoder8/mha/out/dense/w): OnnxMatMul()
  (encoder8/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder8/alpha*input): OnnxBinaryMathOperation()
  (encoder8/mha/out/skip): OnnxBinaryMathOperation()
  (encoder8/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder8/ffn/dense1/w): OnnxMatMul()
  (encoder8/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder8/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder8/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder8/ffn/dense2/w): OnnxMatMul()
  (encoder8/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder8/alpha*out1): OnnxBinaryMathOperation()
  (encoder8/ffn/skip): OnnxBinaryMathOperation()
  (encoder8/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder9/mha/Q/w): OnnxMatMul()
  (encoder9/mha/Q/b): OnnxBinaryMathOperation()
  (encoder9/mha/Q/reshape): OnnxReshape()
  (encoder9/mha/Q/transpose): OnnxTranspose()
  (encoder9/mha/K/w): OnnxMatMul()
  (encoder9/mha/K/b): OnnxBinaryMathOperation()
  (encoder9/mha/K/reshape): OnnxReshape()
  (encoder9/mha/K/transpose): OnnxTranspose()
  (encoder9/mha/V/w): OnnxMatMul()
  (encoder9/mha/V/b): OnnxBinaryMathOperation()
  (encoder9/mha/V/reshape): OnnxReshape()
  (encoder9/mha/V/transpose): OnnxTranspose()
  (encoder9/mha/QK/matmul): OnnxMatMul()
  (encoder9/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder9/smolgen/compress): OnnxMatMul()
  (encoder9/smolgen/compress/reshape): OnnxReshape()
  (encoder9/smolgen/dense1/w): OnnxMatMul()
  (encoder9/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder9/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder9/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder9/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder9/smolgen/dense2/w): OnnxMatMul()
  (encoder9/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder9/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder9/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder9/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder9/smolgen/gen_from/reshape): OnnxReshape()
  (encoder9/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder9/smolgen/out/reshape): OnnxReshape()
  (encoder9/smolgen_weights): OnnxBinaryMathOperation()
  (encoder9/mha/QK/softmax): Softmax(dim=3)
  (encoder9/mha/QKV/matmul): OnnxMatMul()
  (encoder9/mha/out/transpose): OnnxTranspose()
  (encoder9/mha/out/reshape): OnnxReshape()
  (encoder9/mha/out/dense/w): OnnxMatMul()
  (encoder9/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder9/alpha*input): OnnxBinaryMathOperation()
  (encoder9/mha/out/skip): OnnxBinaryMathOperation()
  (encoder9/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder9/ffn/dense1/w): OnnxMatMul()
  (encoder9/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder9/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder9/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder9/ffn/dense2/w): OnnxMatMul()
  (encoder9/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder9/alpha*out1): OnnxBinaryMathOperation()
  (encoder9/ffn/skip): OnnxBinaryMathOperation()
  (encoder9/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder10/mha/Q/w): OnnxMatMul()
  (encoder10/mha/Q/b): OnnxBinaryMathOperation()
  (encoder10/mha/Q/reshape): OnnxReshape()
  (encoder10/mha/Q/transpose): OnnxTranspose()
  (encoder10/mha/K/w): OnnxMatMul()
  (encoder10/mha/K/b): OnnxBinaryMathOperation()
  (encoder10/mha/K/reshape): OnnxReshape()
  (encoder10/mha/K/transpose): OnnxTranspose()
  (encoder10/mha/V/w): OnnxMatMul()
  (encoder10/mha/V/b): OnnxBinaryMathOperation()
  (encoder10/mha/V/reshape): OnnxReshape()
  (encoder10/mha/V/transpose): OnnxTranspose()
  (encoder10/mha/QK/matmul): OnnxMatMul()
  (encoder10/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder10/smolgen/compress): OnnxMatMul()
  (encoder10/smolgen/compress/reshape): OnnxReshape()
  (encoder10/smolgen/dense1/w): OnnxMatMul()
  (encoder10/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder10/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder10/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder10/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder10/smolgen/dense2/w): OnnxMatMul()
  (encoder10/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder10/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder10/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder10/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder10/smolgen/gen_from/reshape): OnnxReshape()
  (encoder10/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder10/smolgen/out/reshape): OnnxReshape()
  (encoder10/smolgen_weights): OnnxBinaryMathOperation()
  (encoder10/mha/QK/softmax): Softmax(dim=3)
  (encoder10/mha/QKV/matmul): OnnxMatMul()
  (encoder10/mha/out/transpose): OnnxTranspose()
  (encoder10/mha/out/reshape): OnnxReshape()
  (encoder10/mha/out/dense/w): OnnxMatMul()
  (encoder10/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder10/alpha*input): OnnxBinaryMathOperation()
  (encoder10/mha/out/skip): OnnxBinaryMathOperation()
  (encoder10/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder10/ffn/dense1/w): OnnxMatMul()
  (encoder10/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder10/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder10/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder10/ffn/dense2/w): OnnxMatMul()
  (encoder10/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder10/alpha*out1): OnnxBinaryMathOperation()
  (encoder10/ffn/skip): OnnxBinaryMathOperation()
  (encoder10/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder11/mha/Q/w): OnnxMatMul()
  (encoder11/mha/Q/b): OnnxBinaryMathOperation()
  (encoder11/mha/Q/reshape): OnnxReshape()
  (encoder11/mha/Q/transpose): OnnxTranspose()
  (encoder11/mha/K/w): OnnxMatMul()
  (encoder11/mha/K/b): OnnxBinaryMathOperation()
  (encoder11/mha/K/reshape): OnnxReshape()
  (encoder11/mha/K/transpose): OnnxTranspose()
  (encoder11/mha/V/w): OnnxMatMul()
  (encoder11/mha/V/b): OnnxBinaryMathOperation()
  (encoder11/mha/V/reshape): OnnxReshape()
  (encoder11/mha/V/transpose): OnnxTranspose()
  (encoder11/mha/QK/matmul): OnnxMatMul()
  (encoder11/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder11/smolgen/compress): OnnxMatMul()
  (encoder11/smolgen/compress/reshape): OnnxReshape()
  (encoder11/smolgen/dense1/w): OnnxMatMul()
  (encoder11/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder11/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder11/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder11/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder11/smolgen/dense2/w): OnnxMatMul()
  (encoder11/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder11/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder11/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder11/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder11/smolgen/gen_from/reshape): OnnxReshape()
  (encoder11/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder11/smolgen/out/reshape): OnnxReshape()
  (encoder11/smolgen_weights): OnnxBinaryMathOperation()
  (encoder11/mha/QK/softmax): Softmax(dim=3)
  (encoder11/mha/QKV/matmul): OnnxMatMul()
  (encoder11/mha/out/transpose): OnnxTranspose()
  (encoder11/mha/out/reshape): OnnxReshape()
  (encoder11/mha/out/dense/w): OnnxMatMul()
  (encoder11/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder11/alpha*input): OnnxBinaryMathOperation()
  (encoder11/mha/out/skip): OnnxBinaryMathOperation()
  (encoder11/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder11/ffn/dense1/w): OnnxMatMul()
  (encoder11/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder11/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder11/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder11/ffn/dense2/w): OnnxMatMul()
  (encoder11/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder11/alpha*out1): OnnxBinaryMathOperation()
  (encoder11/ffn/skip): OnnxBinaryMathOperation()
  (encoder11/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder12/mha/Q/w): OnnxMatMul()
  (encoder12/mha/Q/b): OnnxBinaryMathOperation()
  (encoder12/mha/Q/reshape): OnnxReshape()
  (encoder12/mha/Q/transpose): OnnxTranspose()
  (encoder12/mha/K/w): OnnxMatMul()
  (encoder12/mha/K/b): OnnxBinaryMathOperation()
  (encoder12/mha/K/reshape): OnnxReshape()
  (encoder12/mha/K/transpose): OnnxTranspose()
  (encoder12/mha/V/w): OnnxMatMul()
  (encoder12/mha/V/b): OnnxBinaryMathOperation()
  (encoder12/mha/V/reshape): OnnxReshape()
  (encoder12/mha/V/transpose): OnnxTranspose()
  (encoder12/mha/QK/matmul): OnnxMatMul()
  (encoder12/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder12/smolgen/compress): OnnxMatMul()
  (encoder12/smolgen/compress/reshape): OnnxReshape()
  (encoder12/smolgen/dense1/w): OnnxMatMul()
  (encoder12/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder12/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder12/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder12/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder12/smolgen/dense2/w): OnnxMatMul()
  (encoder12/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder12/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder12/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder12/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder12/smolgen/gen_from/reshape): OnnxReshape()
  (encoder12/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder12/smolgen/out/reshape): OnnxReshape()
  (encoder12/smolgen_weights): OnnxBinaryMathOperation()
  (encoder12/mha/QK/softmax): Softmax(dim=3)
  (encoder12/mha/QKV/matmul): OnnxMatMul()
  (encoder12/mha/out/transpose): OnnxTranspose()
  (encoder12/mha/out/reshape): OnnxReshape()
  (encoder12/mha/out/dense/w): OnnxMatMul()
  (encoder12/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder12/alpha*input): OnnxBinaryMathOperation()
  (encoder12/mha/out/skip): OnnxBinaryMathOperation()
  (encoder12/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder12/ffn/dense1/w): OnnxMatMul()
  (encoder12/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder12/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder12/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder12/ffn/dense2/w): OnnxMatMul()
  (encoder12/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder12/alpha*out1): OnnxBinaryMathOperation()
  (encoder12/ffn/skip): OnnxBinaryMathOperation()
  (encoder12/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder13/mha/Q/w): OnnxMatMul()
  (encoder13/mha/Q/b): OnnxBinaryMathOperation()
  (encoder13/mha/Q/reshape): OnnxReshape()
  (encoder13/mha/Q/transpose): OnnxTranspose()
  (encoder13/mha/K/w): OnnxMatMul()
  (encoder13/mha/K/b): OnnxBinaryMathOperation()
  (encoder13/mha/K/reshape): OnnxReshape()
  (encoder13/mha/K/transpose): OnnxTranspose()
  (encoder13/mha/V/w): OnnxMatMul()
  (encoder13/mha/V/b): OnnxBinaryMathOperation()
  (encoder13/mha/V/reshape): OnnxReshape()
  (encoder13/mha/V/transpose): OnnxTranspose()
  (encoder13/mha/QK/matmul): OnnxMatMul()
  (encoder13/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder13/smolgen/compress): OnnxMatMul()
  (encoder13/smolgen/compress/reshape): OnnxReshape()
  (encoder13/smolgen/dense1/w): OnnxMatMul()
  (encoder13/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder13/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder13/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder13/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder13/smolgen/dense2/w): OnnxMatMul()
  (encoder13/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder13/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder13/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder13/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder13/smolgen/gen_from/reshape): OnnxReshape()
  (encoder13/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder13/smolgen/out/reshape): OnnxReshape()
  (encoder13/smolgen_weights): OnnxBinaryMathOperation()
  (encoder13/mha/QK/softmax): Softmax(dim=3)
  (encoder13/mha/QKV/matmul): OnnxMatMul()
  (encoder13/mha/out/transpose): OnnxTranspose()
  (encoder13/mha/out/reshape): OnnxReshape()
  (encoder13/mha/out/dense/w): OnnxMatMul()
  (encoder13/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder13/alpha*input): OnnxBinaryMathOperation()
  (encoder13/mha/out/skip): OnnxBinaryMathOperation()
  (encoder13/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder13/ffn/dense1/w): OnnxMatMul()
  (encoder13/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder13/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder13/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder13/ffn/dense2/w): OnnxMatMul()
  (encoder13/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder13/alpha*out1): OnnxBinaryMathOperation()
  (encoder13/ffn/skip): OnnxBinaryMathOperation()
  (encoder13/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder14/mha/Q/w): OnnxMatMul()
  (encoder14/mha/Q/b): OnnxBinaryMathOperation()
  (encoder14/mha/Q/reshape): OnnxReshape()
  (encoder14/mha/Q/transpose): OnnxTranspose()
  (encoder14/mha/K/w): OnnxMatMul()
  (encoder14/mha/K/b): OnnxBinaryMathOperation()
  (encoder14/mha/K/reshape): OnnxReshape()
  (encoder14/mha/K/transpose): OnnxTranspose()
  (encoder14/mha/V/w): OnnxMatMul()
  (encoder14/mha/V/b): OnnxBinaryMathOperation()
  (encoder14/mha/V/reshape): OnnxReshape()
  (encoder14/mha/V/transpose): OnnxTranspose()
  (encoder14/mha/QK/matmul): OnnxMatMul()
  (encoder14/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder14/smolgen/compress): OnnxMatMul()
  (encoder14/smolgen/compress/reshape): OnnxReshape()
  (encoder14/smolgen/dense1/w): OnnxMatMul()
  (encoder14/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder14/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder14/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder14/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder14/smolgen/dense2/w): OnnxMatMul()
  (encoder14/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder14/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder14/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder14/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder14/smolgen/gen_from/reshape): OnnxReshape()
  (encoder14/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder14/smolgen/out/reshape): OnnxReshape()
  (encoder14/smolgen_weights): OnnxBinaryMathOperation()
  (encoder14/mha/QK/softmax): Softmax(dim=3)
  (encoder14/mha/QKV/matmul): OnnxMatMul()
  (encoder14/mha/out/transpose): OnnxTranspose()
  (encoder14/mha/out/reshape): OnnxReshape()
  (encoder14/mha/out/dense/w): OnnxMatMul()
  (encoder14/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder14/alpha*input): OnnxBinaryMathOperation()
  (encoder14/mha/out/skip): OnnxBinaryMathOperation()
  (encoder14/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder14/ffn/dense1/w): OnnxMatMul()
  (encoder14/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder14/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder14/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder14/ffn/dense2/w): OnnxMatMul()
  (encoder14/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder14/alpha*out1): OnnxBinaryMathOperation()
  (encoder14/ffn/skip): OnnxBinaryMathOperation()
  (encoder14/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (policy/dense1/matmul): OnnxMatMul()
  (policy/dense1/add): OnnxBinaryMathOperation()
  (policy/dense1/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (policy/dense1/mish/tanh): OnnxFunction()
  (policy/dense1/mish): OnnxBinaryMathOperation()
  (policy/Q/matmul): OnnxMatMul()
  (policy/Q/add): OnnxBinaryMathOperation()
  (policy/Q/reshape): OnnxReshape()
  (policy/K/matmul): OnnxMatMul()
  (policy/K/add): OnnxBinaryMathOperation()
  (policy/K/reshape): OnnxReshape()
  (policy/K/transpose): OnnxTranspose()
  (policy/matmul): OnnxMatMul()
  (policy/scale): OnnxBinaryMathOperation()
  (policy/promotion/slice): OnnxSlice()
  (policy/promotion/matmul): OnnxMatMul()
  (policy/promotion/transpose): OnnxTranspose()
  (policy/promotion/split): OnnxSplit13()
  (policy/promotion/add): OnnxBinaryMathOperation()
  (policy/promotion/transpose2): OnnxTranspose()
  (policy/promotion/reshape): OnnxReshape()
  (policy/promotion/slice2): OnnxSlice()
  (policy/promotion/reshape2): OnnxReshape()
  (policy/promotion/concat): OnnxConcat()
  (policy/promotion/reshape3): OnnxReshape()
  (policy/promotion/add2): OnnxBinaryMathOperation()
  (policy/promotion/reshape4): OnnxReshape()
  (policy/concat): OnnxConcat()
  (policy/reshape): OnnxReshape()
  (output/policy): OnnxGather()
  (value/embed/matmul): OnnxMatMul()
  (value/embed/add): OnnxBinaryMathOperation()
  (value/embed/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (value/embed/mish/tanh): OnnxFunction()
  (value/embed/mish): OnnxBinaryMathOperation()
  (value/reshape): OnnxReshape()
  (value/dense1/matmul): OnnxMatMul()
  (value/dense1/add): OnnxBinaryMathOperation()
  (value/dense1/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (value/dense1/mish/tanh): OnnxFunction()
  (value/dense1/mish): OnnxBinaryMathOperation()
  (value/dense2/matmul): OnnxMatMul()
  (value/dense2/add): OnnxBinaryMathOperation()
  (output/wdl): Softmax(dim=1)
  (mlh/embed/matmul): OnnxMatMul()
  (mlh/embed/add): OnnxBinaryMathOperation()
  (mlh/embed/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (mlh/embed/mish/tanh): OnnxFunction()
  (mlh/embed/mish): OnnxBinaryMathOperation()
  (mlh/reshape): OnnxReshape()
  (mlh/dense1/matmul): OnnxMatMul()
  (mlh/dense1/add): OnnxBinaryMathOperation()
  (mlh/dense1/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (mlh/dense1/mish/tanh): OnnxFunction()
  (mlh/dense1/mish): OnnxBinaryMathOperation()
  (mlh/dense2/matmul): OnnxMatMul()
  (mlh/dense2/add): OnnxBinaryMathOperation()
  (mlh/dense2/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (mlh/dense2/mish/tanh): OnnxFunction()
  (mlh/dense2/mish): OnnxBinaryMathOperation()
  (output/mlh): OnnxCopyIdentity()
)
[5]:
from lczerolens import LczeroBoard
from lczerolens.lenses import ActivationLens

# Example position used by this notebook (FEN fields: white to move, full-move number 22).
board = LczeroBoard(fen="1rb1rbk1/2qn1p1p/p2p2p1/1ppPp2n/PP2P3/2P1BN1P/R1BN1PP1/3QR1K1 w - - 0 22")
# Raw string: in a plain literal "\d" is an invalid escape sequence (it warns on
# recent Python versions). The resulting pattern text is unchanged, so the lens
# still selects the QK softmax module of every encoder layer.
lens = ActivationLens(r"encoder\d+/mha/QK/softmax")
results = lens.analyse(transformer_model, board)
# Displayed keys are the matched module names suffixed with "_output".
results.keys()
[5]:
dict_keys(['encoder0/mha/QK/softmax_output', 'encoder1/mha/QK/softmax_output', 'encoder2/mha/QK/softmax_output', 'encoder3/mha/QK/softmax_output', 'encoder4/mha/QK/softmax_output', 'encoder5/mha/QK/softmax_output', 'encoder6/mha/QK/softmax_output', 'encoder7/mha/QK/softmax_output', 'encoder8/mha/QK/softmax_output', 'encoder9/mha/QK/softmax_output', 'encoder10/mha/QK/softmax_output', 'encoder11/mha/QK/softmax_output', 'encoder12/mha/QK/softmax_output', 'encoder13/mha/QK/softmax_output', 'encoder14/mha/QK/softmax_output'])
[6]:
import chess
import IPython.display

# Selection of the attention map to display: batch entry, encoder layer,
# attention head, and the board square used to index the head's attention matrix.
batch_index, layer, head = 0, 9, 5
piece = chess.F3

# Softmax output of the chosen layer, then the slice for one batch entry and head.
layer_attention = results[f"encoder{layer}/mha/QK/softmax_output"]
attention_weights = layer_attention[batch_index, head]

# Row of the attention matrix selected by the square, detached before rendering.
piece_attention = attention_weights[piece].detach()
svg_board, svg_colorbar = board.render_heatmap(piece_attention)
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-01-15T16:40:56.698191 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

Visualise Gradients#

[7]:
# Load the second ONNX file fetched in the setup cell; the module listing printed
# below consists of Conv2d residual blocks, hence the name `cnn_model`.
cnn_model = LczeroModel.from_path("lc0-19-4508.onnx")
# Bare expression: displays the model's module listing as the cell output.
cnn_model
[7]:
GraphModule(
  (inputconv): Conv2d(112, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (inputconv/relu): ReLU()
  (block0/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block0/conv1/relu): ReLU()
  (block0/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block0/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (initializers): Module()
  (block0/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block0/conv2/se/matmul1): OnnxMatMul()
  (block0/conv2/se/add1): OnnxBinaryMathOperation()
  (block0/conv2/se/relu): ReLU()
  (block0/conv2/se/matmul2): OnnxMatMul()
  (block0/conv2/se/add2): OnnxBinaryMathOperation()
  (block0/conv2/se/reshape): OnnxReshape()
  (block0/conv2/se/split): OnnxSplit13()
  (block0/conv2/se/sigmoid): Sigmoid()
  (block0/conv2/se/mul): OnnxBinaryMathOperation()
  (block0/conv2/se/add3): OnnxBinaryMathOperation()
  (block0/conv2/mixin): OnnxBinaryMathOperation()
  (block0/conv2/relu): ReLU()
  (block1/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block1/conv1/relu): ReLU()
  (block1/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block1/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block1/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block1/conv2/se/matmul1): OnnxMatMul()
  (block1/conv2/se/add1): OnnxBinaryMathOperation()
  (block1/conv2/se/relu): ReLU()
  (block1/conv2/se/matmul2): OnnxMatMul()
  (block1/conv2/se/add2): OnnxBinaryMathOperation()
  (block1/conv2/se/reshape): OnnxReshape()
  (block1/conv2/se/split): OnnxSplit13()
  (block1/conv2/se/sigmoid): Sigmoid()
  (block1/conv2/se/mul): OnnxBinaryMathOperation()
  (block1/conv2/se/add3): OnnxBinaryMathOperation()
  (block1/conv2/mixin): OnnxBinaryMathOperation()
  (block1/conv2/relu): ReLU()
  (block2/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block2/conv1/relu): ReLU()
  (block2/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block2/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block2/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block2/conv2/se/matmul1): OnnxMatMul()
  (block2/conv2/se/add1): OnnxBinaryMathOperation()
  (block2/conv2/se/relu): ReLU()
  (block2/conv2/se/matmul2): OnnxMatMul()
  (block2/conv2/se/add2): OnnxBinaryMathOperation()
  (block2/conv2/se/reshape): OnnxReshape()
  (block2/conv2/se/split): OnnxSplit13()
  (block2/conv2/se/sigmoid): Sigmoid()
  (block2/conv2/se/mul): OnnxBinaryMathOperation()
  (block2/conv2/se/add3): OnnxBinaryMathOperation()
  (block2/conv2/mixin): OnnxBinaryMathOperation()
  (block2/conv2/relu): ReLU()
  (block3/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block3/conv1/relu): ReLU()
  (block3/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block3/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block3/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block3/conv2/se/matmul1): OnnxMatMul()
  (block3/conv2/se/add1): OnnxBinaryMathOperation()
  (block3/conv2/se/relu): ReLU()
  (block3/conv2/se/matmul2): OnnxMatMul()
  (block3/conv2/se/add2): OnnxBinaryMathOperation()
  (block3/conv2/se/reshape): OnnxReshape()
  (block3/conv2/se/split): OnnxSplit13()
  (block3/conv2/se/sigmoid): Sigmoid()
  (block3/conv2/se/mul): OnnxBinaryMathOperation()
  (block3/conv2/se/add3): OnnxBinaryMathOperation()
  (block3/conv2/mixin): OnnxBinaryMathOperation()
  (block3/conv2/relu): ReLU()
  (block4/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block4/conv1/relu): ReLU()
  (block4/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block4/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block4/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block4/conv2/se/matmul1): OnnxMatMul()
  (block4/conv2/se/add1): OnnxBinaryMathOperation()
  (block4/conv2/se/relu): ReLU()
  (block4/conv2/se/matmul2): OnnxMatMul()
  (block4/conv2/se/add2): OnnxBinaryMathOperation()
  (block4/conv2/se/reshape): OnnxReshape()
  (block4/conv2/se/split): OnnxSplit13()
  (block4/conv2/se/sigmoid): Sigmoid()
  (block4/conv2/se/mul): OnnxBinaryMathOperation()
  (block4/conv2/se/add3): OnnxBinaryMathOperation()
  (block4/conv2/mixin): OnnxBinaryMathOperation()
  (block4/conv2/relu): ReLU()
  (block5/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block5/conv1/relu): ReLU()
  (block5/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block5/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block5/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block5/conv2/se/matmul1): OnnxMatMul()
  (block5/conv2/se/add1): OnnxBinaryMathOperation()
  (block5/conv2/se/relu): ReLU()
  (block5/conv2/se/matmul2): OnnxMatMul()
  (block5/conv2/se/add2): OnnxBinaryMathOperation()
  (block5/conv2/se/reshape): OnnxReshape()
  (block5/conv2/se/split): OnnxSplit13()
  (block5/conv2/se/sigmoid): Sigmoid()
  (block5/conv2/se/mul): OnnxBinaryMathOperation()
  (block5/conv2/se/add3): OnnxBinaryMathOperation()
  (block5/conv2/mixin): OnnxBinaryMathOperation()
  (block5/conv2/relu): ReLU()
  (block6/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block6/conv1/relu): ReLU()
  (block6/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block6/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block6/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block6/conv2/se/matmul1): OnnxMatMul()
  (block6/conv2/se/add1): OnnxBinaryMathOperation()
  (block6/conv2/se/relu): ReLU()
  (block6/conv2/se/matmul2): OnnxMatMul()
  (block6/conv2/se/add2): OnnxBinaryMathOperation()
  (block6/conv2/se/reshape): OnnxReshape()
  (block6/conv2/se/split): OnnxSplit13()
  (block6/conv2/se/sigmoid): Sigmoid()
  (block6/conv2/se/mul): OnnxBinaryMathOperation()
  (block6/conv2/se/add3): OnnxBinaryMathOperation()
  (block6/conv2/mixin): OnnxBinaryMathOperation()
  (block6/conv2/relu): ReLU()
  (block7/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block7/conv1/relu): ReLU()
  (block7/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block7/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block7/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block7/conv2/se/matmul1): OnnxMatMul()
  (block7/conv2/se/add1): OnnxBinaryMathOperation()
  (block7/conv2/se/relu): ReLU()
  (block7/conv2/se/matmul2): OnnxMatMul()
  (block7/conv2/se/add2): OnnxBinaryMathOperation()
  (block7/conv2/se/reshape): OnnxReshape()
  (block7/conv2/se/split): OnnxSplit13()
  (block7/conv2/se/sigmoid): Sigmoid()
  (block7/conv2/se/mul): OnnxBinaryMathOperation()
  (block7/conv2/se/add3): OnnxBinaryMathOperation()
  (block7/conv2/mixin): OnnxBinaryMathOperation()
  (block7/conv2/relu): ReLU()
  (block8/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block8/conv1/relu): ReLU()
  (block8/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block8/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block8/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block8/conv2/se/matmul1): OnnxMatMul()
  (block8/conv2/se/add1): OnnxBinaryMathOperation()
  (block8/conv2/se/relu): ReLU()
  (block8/conv2/se/matmul2): OnnxMatMul()
  (block8/conv2/se/add2): OnnxBinaryMathOperation()
  (block8/conv2/se/reshape): OnnxReshape()
  (block8/conv2/se/split): OnnxSplit13()
  (block8/conv2/se/sigmoid): Sigmoid()
  (block8/conv2/se/mul): OnnxBinaryMathOperation()
  (block8/conv2/se/add3): OnnxBinaryMathOperation()
  (block8/conv2/mixin): OnnxBinaryMathOperation()
  (block8/conv2/relu): ReLU()
  (block9/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block9/conv1/relu): ReLU()
  (block9/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block9/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block9/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block9/conv2/se/matmul1): OnnxMatMul()
  (block9/conv2/se/add1): OnnxBinaryMathOperation()
  (block9/conv2/se/relu): ReLU()
  (block9/conv2/se/matmul2): OnnxMatMul()
  (block9/conv2/se/add2): OnnxBinaryMathOperation()
  (block9/conv2/se/reshape): OnnxReshape()
  (block9/conv2/se/split): OnnxSplit13()
  (block9/conv2/se/sigmoid): Sigmoid()
  (block9/conv2/se/mul): OnnxBinaryMathOperation()
  (block9/conv2/se/add3): OnnxBinaryMathOperation()
  (block9/conv2/mixin): OnnxBinaryMathOperation()
  (block9/conv2/relu): ReLU()
  (block10/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block10/conv1/relu): ReLU()
  (block10/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block10/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block10/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block10/conv2/se/matmul1): OnnxMatMul()
  (block10/conv2/se/add1): OnnxBinaryMathOperation()
  (block10/conv2/se/relu): ReLU()
  (block10/conv2/se/matmul2): OnnxMatMul()
  (block10/conv2/se/add2): OnnxBinaryMathOperation()
  (block10/conv2/se/reshape): OnnxReshape()
  (block10/conv2/se/split): OnnxSplit13()
  (block10/conv2/se/sigmoid): Sigmoid()
  (block10/conv2/se/mul): OnnxBinaryMathOperation()
  (block10/conv2/se/add3): OnnxBinaryMathOperation()
  (block10/conv2/mixin): OnnxBinaryMathOperation()
  (block10/conv2/relu): ReLU()
  (block11/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block11/conv1/relu): ReLU()
  (block11/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block11/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block11/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block11/conv2/se/matmul1): OnnxMatMul()
  (block11/conv2/se/add1): OnnxBinaryMathOperation()
  (block11/conv2/se/relu): ReLU()
  (block11/conv2/se/matmul2): OnnxMatMul()
  (block11/conv2/se/add2): OnnxBinaryMathOperation()
  (block11/conv2/se/reshape): OnnxReshape()
  (block11/conv2/se/split): OnnxSplit13()
  (block11/conv2/se/sigmoid): Sigmoid()
  (block11/conv2/se/mul): OnnxBinaryMathOperation()
  (block11/conv2/se/add3): OnnxBinaryMathOperation()
  (block11/conv2/mixin): OnnxBinaryMathOperation()
  (block11/conv2/relu): ReLU()
  (block12/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block12/conv1/relu): ReLU()
  (block12/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block12/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block12/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block12/conv2/se/matmul1): OnnxMatMul()
  (block12/conv2/se/add1): OnnxBinaryMathOperation()
  (block12/conv2/se/relu): ReLU()
  (block12/conv2/se/matmul2): OnnxMatMul()
  (block12/conv2/se/add2): OnnxBinaryMathOperation()
  (block12/conv2/se/reshape): OnnxReshape()
  (block12/conv2/se/split): OnnxSplit13()
  (block12/conv2/se/sigmoid): Sigmoid()
  (block12/conv2/se/mul): OnnxBinaryMathOperation()
  (block12/conv2/se/add3): OnnxBinaryMathOperation()
  (block12/conv2/mixin): OnnxBinaryMathOperation()
  (block12/conv2/relu): ReLU()
  (block13/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block13/conv1/relu): ReLU()
  (block13/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block13/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block13/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block13/conv2/se/matmul1): OnnxMatMul()
  (block13/conv2/se/add1): OnnxBinaryMathOperation()
  (block13/conv2/se/relu): ReLU()
  (block13/conv2/se/matmul2): OnnxMatMul()
  (block13/conv2/se/add2): OnnxBinaryMathOperation()
  (block13/conv2/se/reshape): OnnxReshape()
  (block13/conv2/se/split): OnnxSplit13()
  (block13/conv2/se/sigmoid): Sigmoid()
  (block13/conv2/se/mul): OnnxBinaryMathOperation()
  (block13/conv2/se/add3): OnnxBinaryMathOperation()
  (block13/conv2/mixin): OnnxBinaryMathOperation()
  (block13/conv2/relu): ReLU()
  (block14/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block14/conv1/relu): ReLU()
  (block14/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block14/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block14/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block14/conv2/se/matmul1): OnnxMatMul()
  (block14/conv2/se/add1): OnnxBinaryMathOperation()
  (block14/conv2/se/relu): ReLU()
  (block14/conv2/se/matmul2): OnnxMatMul()
  (block14/conv2/se/add2): OnnxBinaryMathOperation()
  (block14/conv2/se/reshape): OnnxReshape()
  (block14/conv2/se/split): OnnxSplit13()
  (block14/conv2/se/sigmoid): Sigmoid()
  (block14/conv2/se/mul): OnnxBinaryMathOperation()
  (block14/conv2/se/add3): OnnxBinaryMathOperation()
  (block14/conv2/mixin): OnnxBinaryMathOperation()
  (block14/conv2/relu): ReLU()
  (block15/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block15/conv1/relu): ReLU()
  (block15/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block15/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block15/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block15/conv2/se/matmul1): OnnxMatMul()
  (block15/conv2/se/add1): OnnxBinaryMathOperation()
  (block15/conv2/se/relu): ReLU()
  (block15/conv2/se/matmul2): OnnxMatMul()
  (block15/conv2/se/add2): OnnxBinaryMathOperation()
  (block15/conv2/se/reshape): OnnxReshape()
  (block15/conv2/se/split): OnnxSplit13()
  (block15/conv2/se/sigmoid): Sigmoid()
  (block15/conv2/se/mul): OnnxBinaryMathOperation()
  (block15/conv2/se/add3): OnnxBinaryMathOperation()
  (block15/conv2/mixin): OnnxBinaryMathOperation()
  (block15/conv2/relu): ReLU()
  (block16/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block16/conv1/relu): ReLU()
  (block16/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block16/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block16/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block16/conv2/se/matmul1): OnnxMatMul()
  (block16/conv2/se/add1): OnnxBinaryMathOperation()
  (block16/conv2/se/relu): ReLU()
  (block16/conv2/se/matmul2): OnnxMatMul()
  (block16/conv2/se/add2): OnnxBinaryMathOperation()
  (block16/conv2/se/reshape): OnnxReshape()
  (block16/conv2/se/split): OnnxSplit13()
  (block16/conv2/se/sigmoid): Sigmoid()
  (block16/conv2/se/mul): OnnxBinaryMathOperation()
  (block16/conv2/se/add3): OnnxBinaryMathOperation()
  (block16/conv2/mixin): OnnxBinaryMathOperation()
  (block16/conv2/relu): ReLU()
  (block17/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block17/conv1/relu): ReLU()
  (block17/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block17/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block17/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block17/conv2/se/matmul1): OnnxMatMul()
  (block17/conv2/se/add1): OnnxBinaryMathOperation()
  (block17/conv2/se/relu): ReLU()
  (block17/conv2/se/matmul2): OnnxMatMul()
  (block17/conv2/se/add2): OnnxBinaryMathOperation()
  (block17/conv2/se/reshape): OnnxReshape()
  (block17/conv2/se/split): OnnxSplit13()
  (block17/conv2/se/sigmoid): Sigmoid()
  (block17/conv2/se/mul): OnnxBinaryMathOperation()
  (block17/conv2/se/add3): OnnxBinaryMathOperation()
  (block17/conv2/mixin): OnnxBinaryMathOperation()
  (block17/conv2/relu): ReLU()
  (block18/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block18/conv1/relu): ReLU()
  (block18/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block18/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (block18/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block18/conv2/se/matmul1): OnnxMatMul()
  (block18/conv2/se/add1): OnnxBinaryMathOperation()
  (block18/conv2/se/relu): ReLU()
  (block18/conv2/se/matmul2): OnnxMatMul()
  (block18/conv2/se/add2): OnnxBinaryMathOperation()
  (block18/conv2/se/reshape): OnnxReshape()
  (block18/conv2/se/split): OnnxSplit13()
  (block18/conv2/se/sigmoid): Sigmoid()
  (block18/conv2/se/mul): OnnxBinaryMathOperation()
  (block18/conv2/se/add3): OnnxBinaryMathOperation()
  (block18/conv2/mixin): OnnxBinaryMathOperation()
  (block18/conv2/relu): ReLU()
  (policy/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (policy/conv1/relu): ReLU()
  (policy/conv2): Conv2d(256, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (policy/flatten): OnnxReshape()
  (output/policy): OnnxGather()
  (value/conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
  (value/conv/relu): ReLU()
  (value/reshape): OnnxReshape()
  (value/dense1/matmul): OnnxMatMul()
  (value/dense1/add): OnnxBinaryMathOperation()
  (value/dense1/relu): ReLU()
  (value/dense2/matmul): OnnxMatMul()
  (value/dense2/add): OnnxBinaryMathOperation()
  (output/wdl): Softmax(dim=1)
  (mlh/conv): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))
  (mlh/conv/relu): ReLU()
  (mlh/reshape): OnnxReshape()
  (mlh/dense1/matmul): OnnxMatMul()
  (mlh/dense1/add): OnnxBinaryMathOperation()
  (mlh/dense1/relu): ReLU()
  (mlh/dense2/matmul): OnnxMatMul()
  (mlh/dense2/add): OnnxBinaryMathOperation()
  (mlh/dense2/relu): ReLU()
  (output/mlh): OnnxCopyIdentity()
)
[8]:
from lczerolens.lenses import GradientLens

# Gradient lens targeting the modules named on either side of "|"; the result
# printed below additionally contains the gradient of the input ("input_grad").
lens = GradientLens(pattern="output/wdl|value/dense2/add")


def init_target(model):
    """Return column 0 of the ``output/wdl`` module's output for every batch row.

    The module name contains a slash, so it is looked up with ``getattr``
    rather than dotted attribute access. The returned value is what this
    notebook passes to ``lens.analyse`` through ``init_target``.
    """
    wdl_module = getattr(model, "output/wdl")
    return wdl_module.output[:, 0]


# Analyse the example position with the CNN model, passing `init_target` as the
# lens's target callback.
results = lens.analyse(cnn_model, board, init_target=init_target)
# Bare expression: show which gradient entries were captured.
results.keys()
[8]:
dict_keys(['input_grad', 'value/dense2/add_output_grad', 'output/wdl_output_grad'])
[9]:
batch_index = 0
plane = 1  # N

# Input gradient for one batch entry and one input plane, flattened to 64 values
# (one per square) and rendered with absolute-value normalisation.
plane_grad = results["input_grad"][batch_index, plane]
svg_board, svg_colorbar = board.render_heatmap(plane_grad.view(64).detach(), normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-01-15T16:40:57.991884 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/
[10]:
# Mean of the input gradient over the first 12 planes (dim 1), giving one map
# per batch entry.
first_planes_grad = results["input_grad"][:, :12]
gap_input_grad = first_planes_grad.mean(dim=1)

# `batch_index` comes from the previous cell; flatten the selected map to 64 values.
sample_grad = gap_input_grad[batch_index].view(64).detach()
svg_board, svg_colorbar = board.render_heatmap(sample_grad, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-01-15T16:40:58.045298 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

GradCAM#

[11]:
from lczerolens.lenses import CompositeLens

# Same example position as in the earlier sections, re-created for this section.
board = LczeroBoard(fen="1rb1rbk1/2qn1p1p/p2p2p1/1ppPp2n/PP2P3/2P1BN1P/R1BN1PP1/3QR1K1 w - - 0 22")
# Raw string: in a plain literal "\d" is an invalid escape sequence (it warns on
# recent Python versions). The pattern text itself is unchanged.
pattern = r"block18/conv2/relu|value/conv|policy/conv\d"

# One lens for activations and one for gradients over the same modules, combined
# so a single analyse() call returns both "<module>_output" and
# "<module>_output_grad" entries (as read by the cells below).
act_lens = ActivationLens(pattern)
grad_lens = GradientLens(pattern=pattern)
lens = CompositeLens([act_lens, grad_lens])
[12]:
import einops
import torch.nn.functional as F


def compute_cam_heatmap(values, gradients):
    """Build a Grad-CAM style heatmap from activations and their gradients.

    Args:
        values: activations laid out as (b, c, h, w).
        gradients: gradients with the same (b, c, h, w) layout.

    Returns:
        A (b, h, w) tensor: the channel-weighted sum of ``values`` with
        negative entries clamped to zero by ReLU.
    """
    # One weight per (batch, channel): the spatial mean of the gradient.
    channel_weights = einops.reduce(gradients, "b c h w -> b c", "mean")
    # Contract the channel axis, then keep only the positive evidence.
    return F.relu(einops.einsum(values, channel_weights, "b c h w, b c -> b h w"))

Looking at the value#

[13]:
import IPython.display

# Column of the value head output used as the target by `init_target` below.
wdl_index = 0  # 0: win, 1: draw, 2: loss
# Module whose activations and gradients feed the heatmap; alternatives are
# listed in the trailing comment.
block = "value/conv"  # "block18/conv2/relu" | "value/conv"


def init_target(model):
    """Return column ``wdl_index`` of the ``value/dense2/add`` module's output.

    ``wdl_index`` is read from the notebook's global namespace when the
    function is called. The slash in the module name is why ``getattr`` is used.
    """
    dense2_add = getattr(model, "value/dense2/add")
    return dense2_add.output[:, wdl_index]


results = lens.analyse(cnn_model, board, init_target=init_target)

# Activations and gradients captured for the module selected by `block`.
activations = results[f"{block}_output"]
activation_grads = results[f"{block}_output_grad"]
heatmap = compute_cam_heatmap(activations, activation_grads)

# Flatten to 64 values (one per square) before rendering.
flat_heatmap = heatmap.view(64).detach()
svg_board, svg_colorbar = board.render_heatmap(flat_heatmap, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-01-15T16:40:58.163233 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

Looking at the policy#

[14]:
import torch
from lczerolens.play.sampling import PolicySampler

policy_sampler = PolicySampler(model=cnn_model)
# Only one board is passed, so only the first item produced is consumed
# (presumably one result per input board - verify against PolicySampler).
utilities_iter = iter(policy_sampler.get_utilities([board]))
utilities, legal_indices, _ = next(utilities_iter)

# Positions of the three largest utilities; used to index both tensors below.
top3 = torch.topk(utilities, k=3)
topk_indices = top3.indices
utilities[topk_indices], legal_indices[topk_indices]
[14]:
(tensor([2.9891, 2.5723, 2.2822]), tensor([286, 668,  72]))
[15]:
topk_index = 0

# Separate board built from the same FEN, so pushing a move does not alter `board`.
demo_board = LczeroBoard(fen="1rb1rbk1/2qn1p1p/p2p2p1/1ppPp2n/PP2P3/2P1BN1P/R1BN1PP1/3QR1K1 w - - 0 22")
# Index of the move ranked `topk_index` among the top utilities, decoded and played.
move_index = legal_indices[topk_indices[topk_index]]
move = demo_board.decode_move(move_index)
demo_board.push(move)
display(demo_board)
../../_images/notebooks_features_visualise-heatmaps_21_0.svg
[16]:
import IPython.display

# Index of the move ranked `topk_index` among the top utilities computed above;
# `init_target` below uses it to select a column of the policy output.
policy_index = legal_indices[topk_indices[topk_index]]
# Module whose activations and gradients feed the heatmap; alternatives are
# listed in the trailing comment.
block = "policy/conv2"  # "block18/conv2/relu" | "policy/conv1" | "policy/conv2"


def init_target(model):
    """Return column ``policy_index`` of the ``output/policy`` module's output.

    ``policy_index`` is read from the notebook's global namespace when the
    function is called. The slash in the module name is why ``getattr`` is used.
    """
    policy_module = getattr(model, "output/policy")
    return policy_module.output[:, policy_index]


results = lens.analyse(cnn_model, board, init_target=init_target)

# Activations and gradients captured for the module selected by `block`.
activations = results[f"{block}_output"]
activation_grads = results[f"{block}_output_grad"]
heatmap = compute_cam_heatmap(activations, activation_grads)

# Flatten to 64 values (one per square) before rendering.
flat_heatmap = heatmap.view(64).detach()
svg_board, svg_colorbar = board.render_heatmap(flat_heatmap, normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. . P . B N . P
R . B N . P P .
. . . Q R . K .
2025-01-15T16:40:58.321063 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/
[ ]: