Evidence of Learned Look-Ahead#

Open In Colab

Setup#

[1]:
# Execution environment switch: the install cell compares it against "colab" and
# "colab-dev"; any other value (such as "local") makes that cell install nothing.
MODE = "local"  # "colab" | "colab-dev" | "local"
[2]:
if MODE == "colab":
    !pip install -q lczerolens
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    !pip install -q ./lczerolens
[3]:
!wget https://figshare.com/ndownloader/files/46473526?private_link=adc80845c00b67c8fce5 -O interesting_puzzles.pkl
!wget https://figshare.com/ndownloader/files/46473529?private_link=adc80845c00b67c8fce5 -O lc0.onnx

# When `wget` fail, e.g., "403 Forbidden"
# !pip install gdown
# !gdown https://drive.google.com/uc?id=1GT6I7FAgxWIxA-tzsifBQx0MkKZcR_qz -O interesting_puzzles.pkl
# !gdown https://drive.google.com/uc?id=1PB097ZKd_zTaPHxLK29WKUWmv6KcZ15T -O lc0.onnx
zsh:1: no matches found: https://figshare.com/ndownloader/files/46473526?private_link=adc80845c00b67c8fce5
zsh:1: no matches found: https://figshare.com/ndownloader/files/46473529?private_link=adc80845c00b67c8fce5

Checking Assets#

[4]:
import pickle
import chess
[5]:
# Deserialize the puzzle table written by the download cell (`-O interesting_puzzles.pkl`).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, and this
# file comes from a remote download — only run this on an asset you trust.
with open("interesting_puzzles.pkl", "rb") as f:
    puzzles = pickle.load(f)
# Bare last expression: the notebook displays the first rows of the table.
puzzles.head()
[5]:
PuzzleId FEN Moves Rating RatingDeviation Popularity NbPlays Themes GameUrl OpeningTags principal_variation full_pv_probs full_model_moves full_wdl sparring_full_pv_probs sparring_full_model_moves sparring_wdl different_targets corrupted_fen
22 001w5 1rb2rk1/q5P1/4p2p/3p3p/3P1P2/2P5/2QK3P/3R2R1 b... f8f7 c2h7 g8h7 g7g8q 1073 77 91 189 advancedPawn attraction mate mateIn2 middlegam... https://lichess.org/0e1vxAEn/black#57 NaN [c2h7, g8h7, g7g8q] [0.7599087357521057, 1.0, 0.9156637191772461] [c2h7, g8h7, g7g8q] [0.561260461807251, 0.18697379529476166, 0.251... [0.006373010575771332, 1.0, 0.15674205124378204] [d1a1, g8h7, g7g8q] [0.0053426711820065975, 0.9625791907310486, 0.... False 1r4k1/q4rP1/2b1p2p/3p3p/3P1P2/2P5/2QK3P/3R2R1 ...
111 006wz 2r5/4ppkp/5bp1/1p6/1P6/P3B3/2r2PPP/1R1R2K1 b -... f6b2 b1b2 c2b2 e3d4 f7f6 d4b2 1515 75 73 527 attraction crushing endgame fork long sacrifice https://lichess.org/qT0W6o27/black#43 NaN [b1b2, c2b2, e3d4, f7f6, d4b2] [0.8258911967277527, 0.4339713156223297, 0.935... [b1b2, c2b2, e3d4, f7f6, d4b2] [0.816787600517273, 0.16938228905200958, 0.013... [0.04489869996905327, 0.881671667098999, 0.104... [d1d7, c2b2, e3d4, e7e5, d4b2] [0.0032313962001353502, 0.9889336824417114, 0.... False 2r5/4pp1p/6p1/1p3k2/1P6/P3B3/1br2PPP/1R1R2K1 w...
116 00761 3r2k1/1b3pbR/p2P2P1/3p2N1/2p5/2P2N2/PP6/2K5 b ... f7g6 h7g7 g8g7 g5e6 g7g8 e6d8 1512 75 94 17597 attraction crushing endgame exposedKing fork l... https://lichess.org/vu70Maig/black#55 NaN [h7g7, g8g7, g5e6, g7g8, e6d8] [0.7690635919570923, 0.81597501039505, 0.94436... [h7g7, g8g7, g5e6, g7f6, e6d8] [0.787416398525238, 0.18706896901130676, 0.025... [0.014131303876638412, 0.8102595806121826, 0.4... [f3d4, g8g7, g5e6, g7f6, e6d8] [0.005333698820322752, 0.9863110184669495, 0.0... False 2br2k1/6bR/p2P2p1/3p2N1/2p5/2P2N2/PP6/2K5 w - ...
170 00AoZ 8/1R6/p1pk4/6bp/1QP5/P7/KP6/3r2q1 b - - 2 44 g1c5 b7d7 d6d7 b4c5 1023 77 93 3354 advantage deflection endgame short https://lichess.org/356BAYqk/black#87 NaN [b7d7, d6d7, b4c5] [0.9532407522201538, 0.8907768130302429, 0.966... [b7d7, d6d7, b4c5] [0.8298567533493042, 0.16442830860614777, 0.00... [0.03029155358672142, 0.9142314195632935, 0.81... [b4b3, d6d7, b4c5] [0.01180175319314003, 0.947725236415863, 0.040... False 8/1R6/p1pk4/2q4p/1QP5/P6b/KP6/3r4 w - - 3 45
182 00Bg4 3r2k1/1q3ppp/p2rp3/Qp1B4/7P/P4P2/1PP3P1/1K1R3R... d6d5 a5d8 d5d8 d1d8 1374 85 75 303 backRankMate endgame mate mateIn2 short xRayAt... https://lichess.org/6qWf8wOP/black#41 NaN [a5d8, d5d8, d1d8] [0.9316630959510803, 1.0, 0.9474834203720093] [a5d8, d5d8, d1d8] [0.9491999745368958, 0.04444213956594467, 0.00... [0.008508995175361633, 1.0, 0.772297203540802] [d1d5, d5d8, d1d8] [0.007286431733518839, 0.983742892742157, 0.00... False 3r4/1q3ppp/p3p3/Qp1r4/5k1P/P4P2/1PP3P1/1K1R3R ...
[6]:
from lczerolens import LczeroModel

# Build the model from the ONNX file written by the download cell (`-O lc0.onnx`).
model = LczeroModel.from_path("lc0.onnx")
# Bare last expression: the notebook displays the model's repr, which is the long
# module listing shown in this cell's output.
model
/Users/xmaster/Work/Chess/lczerolens/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[6]:
GraphModule(
  (attn_body/transpose): OnnxTranspose()
  (initializers): Module()
  (attn_body/reshape): OnnxReshape()
  (attn_body/shape): OnnxShape()
  (attn_body/batch): OnnxSlice()
  (attn_body/pos_encoding_shape): OnnxConcat()
  (attn_body/expand): OnnxExpand()
  (attn_body/padded_input): OnnxConcat()
  (attn_body/reshape2): OnnxReshape()
  (attn_body/matmul): OnnxMatMul()
  (attn_body/add): OnnxBinaryMathOperation()
  (attn_body/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (attn_body/mish/tanh): OnnxFunction()
  (attn_body/mish): OnnxBinaryMathOperation()
  (attn_body/ma_gating/rehape1): OnnxReshape()
  (ip_mul_gate): OnnxBinaryMathOperation()
  (ip_add_gate): OnnxBinaryMathOperation()
  (attn_body/ma_gating/rehape2): OnnxReshape()
  (encoder0/mha/Q/w): OnnxMatMul()
  (encoder0/mha/Q/b): OnnxBinaryMathOperation()
  (encoder0/mha/Q/reshape): OnnxReshape()
  (encoder0/mha/Q/transpose): OnnxTranspose()
  (encoder0/mha/K/w): OnnxMatMul()
  (encoder0/mha/K/b): OnnxBinaryMathOperation()
  (encoder0/mha/K/reshape): OnnxReshape()
  (encoder0/mha/K/transpose): OnnxTranspose()
  (encoder0/mha/V/w): OnnxMatMul()
  (encoder0/mha/V/b): OnnxBinaryMathOperation()
  (encoder0/mha/V/reshape): OnnxReshape()
  (encoder0/mha/V/transpose): OnnxTranspose()
  (encoder0/mha/QK/matmul): OnnxMatMul()
  (encoder0/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder0/smolgen/compress): OnnxMatMul()
  (encoder0/smolgen/compress/reshape): OnnxReshape()
  (encoder0/smolgen/dense1/w): OnnxMatMul()
  (encoder0/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder0/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder0/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder0/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder0/smolgen/dense2/w): OnnxMatMul()
  (encoder0/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder0/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder0/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder0/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder0/smolgen/gen_from/reshape): OnnxReshape()
  (encoder0/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder0/smolgen/out/reshape): OnnxReshape()
  (encoder0/smolgen_weights): OnnxBinaryMathOperation()
  (encoder0/mha/QK/softmax): Softmax(dim=3)
  (encoder0/mha/QKV/matmul): OnnxMatMul()
  (encoder0/mha/out/transpose): OnnxTranspose()
  (encoder0/mha/out/reshape): OnnxReshape()
  (encoder0/mha/out/dense/w): OnnxMatMul()
  (encoder0/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder0/alpha*input): OnnxBinaryMathOperation()
  (encoder0/mha/out/skip): OnnxBinaryMathOperation()
  (encoder0/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder0/ffn/dense1/w): OnnxMatMul()
  (encoder0/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder0/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder0/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder0/ffn/dense2/w): OnnxMatMul()
  (encoder0/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder0/alpha*out1): OnnxBinaryMathOperation()
  (encoder0/ffn/skip): OnnxBinaryMathOperation()
  (encoder0/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder1/mha/Q/w): OnnxMatMul()
  (encoder1/mha/Q/b): OnnxBinaryMathOperation()
  (encoder1/mha/Q/reshape): OnnxReshape()
  (encoder1/mha/Q/transpose): OnnxTranspose()
  (encoder1/mha/K/w): OnnxMatMul()
  (encoder1/mha/K/b): OnnxBinaryMathOperation()
  (encoder1/mha/K/reshape): OnnxReshape()
  (encoder1/mha/K/transpose): OnnxTranspose()
  (encoder1/mha/V/w): OnnxMatMul()
  (encoder1/mha/V/b): OnnxBinaryMathOperation()
  (encoder1/mha/V/reshape): OnnxReshape()
  (encoder1/mha/V/transpose): OnnxTranspose()
  (encoder1/mha/QK/matmul): OnnxMatMul()
  (encoder1/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder1/smolgen/compress): OnnxMatMul()
  (encoder1/smolgen/compress/reshape): OnnxReshape()
  (encoder1/smolgen/dense1/w): OnnxMatMul()
  (encoder1/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder1/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder1/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder1/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder1/smolgen/dense2/w): OnnxMatMul()
  (encoder1/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder1/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder1/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder1/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder1/smolgen/gen_from/reshape): OnnxReshape()
  (encoder1/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder1/smolgen/out/reshape): OnnxReshape()
  (encoder1/smolgen_weights): OnnxBinaryMathOperation()
  (encoder1/mha/QK/softmax): Softmax(dim=3)
  (encoder1/mha/QKV/matmul): OnnxMatMul()
  (encoder1/mha/out/transpose): OnnxTranspose()
  (encoder1/mha/out/reshape): OnnxReshape()
  (encoder1/mha/out/dense/w): OnnxMatMul()
  (encoder1/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder1/alpha*input): OnnxBinaryMathOperation()
  (encoder1/mha/out/skip): OnnxBinaryMathOperation()
  (encoder1/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder1/ffn/dense1/w): OnnxMatMul()
  (encoder1/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder1/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder1/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder1/ffn/dense2/w): OnnxMatMul()
  (encoder1/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder1/alpha*out1): OnnxBinaryMathOperation()
  (encoder1/ffn/skip): OnnxBinaryMathOperation()
  (encoder1/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder2/mha/Q/w): OnnxMatMul()
  (encoder2/mha/Q/b): OnnxBinaryMathOperation()
  (encoder2/mha/Q/reshape): OnnxReshape()
  (encoder2/mha/Q/transpose): OnnxTranspose()
  (encoder2/mha/K/w): OnnxMatMul()
  (encoder2/mha/K/b): OnnxBinaryMathOperation()
  (encoder2/mha/K/reshape): OnnxReshape()
  (encoder2/mha/K/transpose): OnnxTranspose()
  (encoder2/mha/V/w): OnnxMatMul()
  (encoder2/mha/V/b): OnnxBinaryMathOperation()
  (encoder2/mha/V/reshape): OnnxReshape()
  (encoder2/mha/V/transpose): OnnxTranspose()
  (encoder2/mha/QK/matmul): OnnxMatMul()
  (encoder2/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder2/smolgen/compress): OnnxMatMul()
  (encoder2/smolgen/compress/reshape): OnnxReshape()
  (encoder2/smolgen/dense1/w): OnnxMatMul()
  (encoder2/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder2/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder2/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder2/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder2/smolgen/dense2/w): OnnxMatMul()
  (encoder2/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder2/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder2/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder2/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder2/smolgen/gen_from/reshape): OnnxReshape()
  (encoder2/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder2/smolgen/out/reshape): OnnxReshape()
  (encoder2/smolgen_weights): OnnxBinaryMathOperation()
  (encoder2/mha/QK/softmax): Softmax(dim=3)
  (encoder2/mha/QKV/matmul): OnnxMatMul()
  (encoder2/mha/out/transpose): OnnxTranspose()
  (encoder2/mha/out/reshape): OnnxReshape()
  (encoder2/mha/out/dense/w): OnnxMatMul()
  (encoder2/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder2/alpha*input): OnnxBinaryMathOperation()
  (encoder2/mha/out/skip): OnnxBinaryMathOperation()
  (encoder2/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder2/ffn/dense1/w): OnnxMatMul()
  (encoder2/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder2/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder2/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder2/ffn/dense2/w): OnnxMatMul()
  (encoder2/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder2/alpha*out1): OnnxBinaryMathOperation()
  (encoder2/ffn/skip): OnnxBinaryMathOperation()
  (encoder2/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder3/mha/Q/w): OnnxMatMul()
  (encoder3/mha/Q/b): OnnxBinaryMathOperation()
  (encoder3/mha/Q/reshape): OnnxReshape()
  (encoder3/mha/Q/transpose): OnnxTranspose()
  (encoder3/mha/K/w): OnnxMatMul()
  (encoder3/mha/K/b): OnnxBinaryMathOperation()
  (encoder3/mha/K/reshape): OnnxReshape()
  (encoder3/mha/K/transpose): OnnxTranspose()
  (encoder3/mha/V/w): OnnxMatMul()
  (encoder3/mha/V/b): OnnxBinaryMathOperation()
  (encoder3/mha/V/reshape): OnnxReshape()
  (encoder3/mha/V/transpose): OnnxTranspose()
  (encoder3/mha/QK/matmul): OnnxMatMul()
  (encoder3/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder3/smolgen/compress): OnnxMatMul()
  (encoder3/smolgen/compress/reshape): OnnxReshape()
  (encoder3/smolgen/dense1/w): OnnxMatMul()
  (encoder3/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder3/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder3/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder3/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder3/smolgen/dense2/w): OnnxMatMul()
  (encoder3/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder3/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder3/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder3/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder3/smolgen/gen_from/reshape): OnnxReshape()
  (encoder3/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder3/smolgen/out/reshape): OnnxReshape()
  (encoder3/smolgen_weights): OnnxBinaryMathOperation()
  (encoder3/mha/QK/softmax): Softmax(dim=3)
  (encoder3/mha/QKV/matmul): OnnxMatMul()
  (encoder3/mha/out/transpose): OnnxTranspose()
  (encoder3/mha/out/reshape): OnnxReshape()
  (encoder3/mha/out/dense/w): OnnxMatMul()
  (encoder3/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder3/alpha*input): OnnxBinaryMathOperation()
  (encoder3/mha/out/skip): OnnxBinaryMathOperation()
  (encoder3/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder3/ffn/dense1/w): OnnxMatMul()
  (encoder3/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder3/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder3/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder3/ffn/dense2/w): OnnxMatMul()
  (encoder3/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder3/alpha*out1): OnnxBinaryMathOperation()
  (encoder3/ffn/skip): OnnxBinaryMathOperation()
  (encoder3/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder4/mha/Q/w): OnnxMatMul()
  (encoder4/mha/Q/b): OnnxBinaryMathOperation()
  (encoder4/mha/Q/reshape): OnnxReshape()
  (encoder4/mha/Q/transpose): OnnxTranspose()
  (encoder4/mha/K/w): OnnxMatMul()
  (encoder4/mha/K/b): OnnxBinaryMathOperation()
  (encoder4/mha/K/reshape): OnnxReshape()
  (encoder4/mha/K/transpose): OnnxTranspose()
  (encoder4/mha/V/w): OnnxMatMul()
  (encoder4/mha/V/b): OnnxBinaryMathOperation()
  (encoder4/mha/V/reshape): OnnxReshape()
  (encoder4/mha/V/transpose): OnnxTranspose()
  (encoder4/mha/QK/matmul): OnnxMatMul()
  (encoder4/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder4/smolgen/compress): OnnxMatMul()
  (encoder4/smolgen/compress/reshape): OnnxReshape()
  (encoder4/smolgen/dense1/w): OnnxMatMul()
  (encoder4/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder4/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder4/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder4/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder4/smolgen/dense2/w): OnnxMatMul()
  (encoder4/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder4/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder4/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder4/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder4/smolgen/gen_from/reshape): OnnxReshape()
  (encoder4/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder4/smolgen/out/reshape): OnnxReshape()
  (encoder4/smolgen_weights): OnnxBinaryMathOperation()
  (encoder4/mha/QK/softmax): Softmax(dim=3)
  (encoder4/mha/QKV/matmul): OnnxMatMul()
  (encoder4/mha/out/transpose): OnnxTranspose()
  (encoder4/mha/out/reshape): OnnxReshape()
  (encoder4/mha/out/dense/w): OnnxMatMul()
  (encoder4/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder4/alpha*input): OnnxBinaryMathOperation()
  (encoder4/mha/out/skip): OnnxBinaryMathOperation()
  (encoder4/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder4/ffn/dense1/w): OnnxMatMul()
  (encoder4/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder4/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder4/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder4/ffn/dense2/w): OnnxMatMul()
  (encoder4/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder4/alpha*out1): OnnxBinaryMathOperation()
  (encoder4/ffn/skip): OnnxBinaryMathOperation()
  (encoder4/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder5/mha/Q/w): OnnxMatMul()
  (encoder5/mha/Q/b): OnnxBinaryMathOperation()
  (encoder5/mha/Q/reshape): OnnxReshape()
  (encoder5/mha/Q/transpose): OnnxTranspose()
  (encoder5/mha/K/w): OnnxMatMul()
  (encoder5/mha/K/b): OnnxBinaryMathOperation()
  (encoder5/mha/K/reshape): OnnxReshape()
  (encoder5/mha/K/transpose): OnnxTranspose()
  (encoder5/mha/V/w): OnnxMatMul()
  (encoder5/mha/V/b): OnnxBinaryMathOperation()
  (encoder5/mha/V/reshape): OnnxReshape()
  (encoder5/mha/V/transpose): OnnxTranspose()
  (encoder5/mha/QK/matmul): OnnxMatMul()
  (encoder5/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder5/smolgen/compress): OnnxMatMul()
  (encoder5/smolgen/compress/reshape): OnnxReshape()
  (encoder5/smolgen/dense1/w): OnnxMatMul()
  (encoder5/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder5/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder5/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder5/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder5/smolgen/dense2/w): OnnxMatMul()
  (encoder5/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder5/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder5/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder5/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder5/smolgen/gen_from/reshape): OnnxReshape()
  (encoder5/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder5/smolgen/out/reshape): OnnxReshape()
  (encoder5/smolgen_weights): OnnxBinaryMathOperation()
  (encoder5/mha/QK/softmax): Softmax(dim=3)
  (encoder5/mha/QKV/matmul): OnnxMatMul()
  (encoder5/mha/out/transpose): OnnxTranspose()
  (encoder5/mha/out/reshape): OnnxReshape()
  (encoder5/mha/out/dense/w): OnnxMatMul()
  (encoder5/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder5/alpha*input): OnnxBinaryMathOperation()
  (encoder5/mha/out/skip): OnnxBinaryMathOperation()
  (encoder5/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder5/ffn/dense1/w): OnnxMatMul()
  (encoder5/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder5/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder5/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder5/ffn/dense2/w): OnnxMatMul()
  (encoder5/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder5/alpha*out1): OnnxBinaryMathOperation()
  (encoder5/ffn/skip): OnnxBinaryMathOperation()
  (encoder5/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder6/mha/Q/w): OnnxMatMul()
  (encoder6/mha/Q/b): OnnxBinaryMathOperation()
  (encoder6/mha/Q/reshape): OnnxReshape()
  (encoder6/mha/Q/transpose): OnnxTranspose()
  (encoder6/mha/K/w): OnnxMatMul()
  (encoder6/mha/K/b): OnnxBinaryMathOperation()
  (encoder6/mha/K/reshape): OnnxReshape()
  (encoder6/mha/K/transpose): OnnxTranspose()
  (encoder6/mha/V/w): OnnxMatMul()
  (encoder6/mha/V/b): OnnxBinaryMathOperation()
  (encoder6/mha/V/reshape): OnnxReshape()
  (encoder6/mha/V/transpose): OnnxTranspose()
  (encoder6/mha/QK/matmul): OnnxMatMul()
  (encoder6/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder6/smolgen/compress): OnnxMatMul()
  (encoder6/smolgen/compress/reshape): OnnxReshape()
  (encoder6/smolgen/dense1/w): OnnxMatMul()
  (encoder6/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder6/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder6/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder6/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder6/smolgen/dense2/w): OnnxMatMul()
  (encoder6/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder6/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder6/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder6/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder6/smolgen/gen_from/reshape): OnnxReshape()
  (encoder6/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder6/smolgen/out/reshape): OnnxReshape()
  (encoder6/smolgen_weights): OnnxBinaryMathOperation()
  (encoder6/mha/QK/softmax): Softmax(dim=3)
  (encoder6/mha/QKV/matmul): OnnxMatMul()
  (encoder6/mha/out/transpose): OnnxTranspose()
  (encoder6/mha/out/reshape): OnnxReshape()
  (encoder6/mha/out/dense/w): OnnxMatMul()
  (encoder6/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder6/alpha*input): OnnxBinaryMathOperation()
  (encoder6/mha/out/skip): OnnxBinaryMathOperation()
  (encoder6/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder6/ffn/dense1/w): OnnxMatMul()
  (encoder6/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder6/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder6/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder6/ffn/dense2/w): OnnxMatMul()
  (encoder6/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder6/alpha*out1): OnnxBinaryMathOperation()
  (encoder6/ffn/skip): OnnxBinaryMathOperation()
  (encoder6/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder7/mha/Q/w): OnnxMatMul()
  (encoder7/mha/Q/b): OnnxBinaryMathOperation()
  (encoder7/mha/Q/reshape): OnnxReshape()
  (encoder7/mha/Q/transpose): OnnxTranspose()
  (encoder7/mha/K/w): OnnxMatMul()
  (encoder7/mha/K/b): OnnxBinaryMathOperation()
  (encoder7/mha/K/reshape): OnnxReshape()
  (encoder7/mha/K/transpose): OnnxTranspose()
  (encoder7/mha/V/w): OnnxMatMul()
  (encoder7/mha/V/b): OnnxBinaryMathOperation()
  (encoder7/mha/V/reshape): OnnxReshape()
  (encoder7/mha/V/transpose): OnnxTranspose()
  (encoder7/mha/QK/matmul): OnnxMatMul()
  (encoder7/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder7/smolgen/compress): OnnxMatMul()
  (encoder7/smolgen/compress/reshape): OnnxReshape()
  (encoder7/smolgen/dense1/w): OnnxMatMul()
  (encoder7/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder7/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder7/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder7/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder7/smolgen/dense2/w): OnnxMatMul()
  (encoder7/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder7/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder7/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder7/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder7/smolgen/gen_from/reshape): OnnxReshape()
  (encoder7/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder7/smolgen/out/reshape): OnnxReshape()
  (encoder7/smolgen_weights): OnnxBinaryMathOperation()
  (encoder7/mha/QK/softmax): Softmax(dim=3)
  (encoder7/mha/QKV/matmul): OnnxMatMul()
  (encoder7/mha/out/transpose): OnnxTranspose()
  (encoder7/mha/out/reshape): OnnxReshape()
  (encoder7/mha/out/dense/w): OnnxMatMul()
  (encoder7/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder7/alpha*input): OnnxBinaryMathOperation()
  (encoder7/mha/out/skip): OnnxBinaryMathOperation()
  (encoder7/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder7/ffn/dense1/w): OnnxMatMul()
  (encoder7/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder7/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder7/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder7/ffn/dense2/w): OnnxMatMul()
  (encoder7/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder7/alpha*out1): OnnxBinaryMathOperation()
  (encoder7/ffn/skip): OnnxBinaryMathOperation()
  (encoder7/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder8/mha/Q/w): OnnxMatMul()
  (encoder8/mha/Q/b): OnnxBinaryMathOperation()
  (encoder8/mha/Q/reshape): OnnxReshape()
  (encoder8/mha/Q/transpose): OnnxTranspose()
  (encoder8/mha/K/w): OnnxMatMul()
  (encoder8/mha/K/b): OnnxBinaryMathOperation()
  (encoder8/mha/K/reshape): OnnxReshape()
  (encoder8/mha/K/transpose): OnnxTranspose()
  (encoder8/mha/V/w): OnnxMatMul()
  (encoder8/mha/V/b): OnnxBinaryMathOperation()
  (encoder8/mha/V/reshape): OnnxReshape()
  (encoder8/mha/V/transpose): OnnxTranspose()
  (encoder8/mha/QK/matmul): OnnxMatMul()
  (encoder8/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder8/smolgen/compress): OnnxMatMul()
  (encoder8/smolgen/compress/reshape): OnnxReshape()
  (encoder8/smolgen/dense1/w): OnnxMatMul()
  (encoder8/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder8/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder8/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder8/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder8/smolgen/dense2/w): OnnxMatMul()
  (encoder8/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder8/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder8/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder8/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder8/smolgen/gen_from/reshape): OnnxReshape()
  (encoder8/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder8/smolgen/out/reshape): OnnxReshape()
  (encoder8/smolgen_weights): OnnxBinaryMathOperation()
  (encoder8/mha/QK/softmax): Softmax(dim=3)
  (encoder8/mha/QKV/matmul): OnnxMatMul()
  (encoder8/mha/out/transpose): OnnxTranspose()
  (encoder8/mha/out/reshape): OnnxReshape()
  (encoder8/mha/out/dense/w): OnnxMatMul()
  (encoder8/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder8/alpha*input): OnnxBinaryMathOperation()
  (encoder8/mha/out/skip): OnnxBinaryMathOperation()
  (encoder8/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder8/ffn/dense1/w): OnnxMatMul()
  (encoder8/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder8/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder8/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder8/ffn/dense2/w): OnnxMatMul()
  (encoder8/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder8/alpha*out1): OnnxBinaryMathOperation()
  (encoder8/ffn/skip): OnnxBinaryMathOperation()
  (encoder8/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder9/mha/Q/w): OnnxMatMul()
  (encoder9/mha/Q/b): OnnxBinaryMathOperation()
  (encoder9/mha/Q/reshape): OnnxReshape()
  (encoder9/mha/Q/transpose): OnnxTranspose()
  (encoder9/mha/K/w): OnnxMatMul()
  (encoder9/mha/K/b): OnnxBinaryMathOperation()
  (encoder9/mha/K/reshape): OnnxReshape()
  (encoder9/mha/K/transpose): OnnxTranspose()
  (encoder9/mha/V/w): OnnxMatMul()
  (encoder9/mha/V/b): OnnxBinaryMathOperation()
  (encoder9/mha/V/reshape): OnnxReshape()
  (encoder9/mha/V/transpose): OnnxTranspose()
  (encoder9/mha/QK/matmul): OnnxMatMul()
  (encoder9/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder9/smolgen/compress): OnnxMatMul()
  (encoder9/smolgen/compress/reshape): OnnxReshape()
  (encoder9/smolgen/dense1/w): OnnxMatMul()
  (encoder9/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder9/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder9/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder9/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder9/smolgen/dense2/w): OnnxMatMul()
  (encoder9/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder9/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder9/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder9/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder9/smolgen/gen_from/reshape): OnnxReshape()
  (encoder9/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder9/smolgen/out/reshape): OnnxReshape()
  (encoder9/smolgen_weights): OnnxBinaryMathOperation()
  (encoder9/mha/QK/softmax): Softmax(dim=3)
  (encoder9/mha/QKV/matmul): OnnxMatMul()
  (encoder9/mha/out/transpose): OnnxTranspose()
  (encoder9/mha/out/reshape): OnnxReshape()
  (encoder9/mha/out/dense/w): OnnxMatMul()
  (encoder9/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder9/alpha*input): OnnxBinaryMathOperation()
  (encoder9/mha/out/skip): OnnxBinaryMathOperation()
  (encoder9/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder9/ffn/dense1/w): OnnxMatMul()
  (encoder9/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder9/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder9/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder9/ffn/dense2/w): OnnxMatMul()
  (encoder9/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder9/alpha*out1): OnnxBinaryMathOperation()
  (encoder9/ffn/skip): OnnxBinaryMathOperation()
  (encoder9/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder10/mha/Q/w): OnnxMatMul()
  (encoder10/mha/Q/b): OnnxBinaryMathOperation()
  (encoder10/mha/Q/reshape): OnnxReshape()
  (encoder10/mha/Q/transpose): OnnxTranspose()
  (encoder10/mha/K/w): OnnxMatMul()
  (encoder10/mha/K/b): OnnxBinaryMathOperation()
  (encoder10/mha/K/reshape): OnnxReshape()
  (encoder10/mha/K/transpose): OnnxTranspose()
  (encoder10/mha/V/w): OnnxMatMul()
  (encoder10/mha/V/b): OnnxBinaryMathOperation()
  (encoder10/mha/V/reshape): OnnxReshape()
  (encoder10/mha/V/transpose): OnnxTranspose()
  (encoder10/mha/QK/matmul): OnnxMatMul()
  (encoder10/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder10/smolgen/compress): OnnxMatMul()
  (encoder10/smolgen/compress/reshape): OnnxReshape()
  (encoder10/smolgen/dense1/w): OnnxMatMul()
  (encoder10/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder10/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder10/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder10/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder10/smolgen/dense2/w): OnnxMatMul()
  (encoder10/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder10/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder10/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder10/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder10/smolgen/gen_from/reshape): OnnxReshape()
  (encoder10/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder10/smolgen/out/reshape): OnnxReshape()
  (encoder10/smolgen_weights): OnnxBinaryMathOperation()
  (encoder10/mha/QK/softmax): Softmax(dim=3)
  (encoder10/mha/QKV/matmul): OnnxMatMul()
  (encoder10/mha/out/transpose): OnnxTranspose()
  (encoder10/mha/out/reshape): OnnxReshape()
  (encoder10/mha/out/dense/w): OnnxMatMul()
  (encoder10/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder10/alpha*input): OnnxBinaryMathOperation()
  (encoder10/mha/out/skip): OnnxBinaryMathOperation()
  (encoder10/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder10/ffn/dense1/w): OnnxMatMul()
  (encoder10/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder10/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder10/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder10/ffn/dense2/w): OnnxMatMul()
  (encoder10/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder10/alpha*out1): OnnxBinaryMathOperation()
  (encoder10/ffn/skip): OnnxBinaryMathOperation()
  (encoder10/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder11/mha/Q/w): OnnxMatMul()
  (encoder11/mha/Q/b): OnnxBinaryMathOperation()
  (encoder11/mha/Q/reshape): OnnxReshape()
  (encoder11/mha/Q/transpose): OnnxTranspose()
  (encoder11/mha/K/w): OnnxMatMul()
  (encoder11/mha/K/b): OnnxBinaryMathOperation()
  (encoder11/mha/K/reshape): OnnxReshape()
  (encoder11/mha/K/transpose): OnnxTranspose()
  (encoder11/mha/V/w): OnnxMatMul()
  (encoder11/mha/V/b): OnnxBinaryMathOperation()
  (encoder11/mha/V/reshape): OnnxReshape()
  (encoder11/mha/V/transpose): OnnxTranspose()
  (encoder11/mha/QK/matmul): OnnxMatMul()
  (encoder11/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder11/smolgen/compress): OnnxMatMul()
  (encoder11/smolgen/compress/reshape): OnnxReshape()
  (encoder11/smolgen/dense1/w): OnnxMatMul()
  (encoder11/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder11/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder11/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder11/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder11/smolgen/dense2/w): OnnxMatMul()
  (encoder11/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder11/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder11/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder11/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder11/smolgen/gen_from/reshape): OnnxReshape()
  (encoder11/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder11/smolgen/out/reshape): OnnxReshape()
  (encoder11/smolgen_weights): OnnxBinaryMathOperation()
  (encoder11/mha/QK/softmax): Softmax(dim=3)
  (encoder11/mha/QKV/matmul): OnnxMatMul()
  (encoder11/mha/out/transpose): OnnxTranspose()
  (encoder11/mha/out/reshape): OnnxReshape()
  (encoder11/mha/out/dense/w): OnnxMatMul()
  (encoder11/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder11/alpha*input): OnnxBinaryMathOperation()
  (encoder11/mha/out/skip): OnnxBinaryMathOperation()
  (encoder11/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder11/ffn/dense1/w): OnnxMatMul()
  (encoder11/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder11/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder11/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder11/ffn/dense2/w): OnnxMatMul()
  (encoder11/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder11/alpha*out1): OnnxBinaryMathOperation()
  (encoder11/ffn/skip): OnnxBinaryMathOperation()
  (encoder11/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder12/mha/Q/w): OnnxMatMul()
  (encoder12/mha/Q/b): OnnxBinaryMathOperation()
  (encoder12/mha/Q/reshape): OnnxReshape()
  (encoder12/mha/Q/transpose): OnnxTranspose()
  (encoder12/mha/K/w): OnnxMatMul()
  (encoder12/mha/K/b): OnnxBinaryMathOperation()
  (encoder12/mha/K/reshape): OnnxReshape()
  (encoder12/mha/K/transpose): OnnxTranspose()
  (encoder12/mha/V/w): OnnxMatMul()
  (encoder12/mha/V/b): OnnxBinaryMathOperation()
  (encoder12/mha/V/reshape): OnnxReshape()
  (encoder12/mha/V/transpose): OnnxTranspose()
  (encoder12/mha/QK/matmul): OnnxMatMul()
  (encoder12/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder12/smolgen/compress): OnnxMatMul()
  (encoder12/smolgen/compress/reshape): OnnxReshape()
  (encoder12/smolgen/dense1/w): OnnxMatMul()
  (encoder12/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder12/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder12/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder12/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder12/smolgen/dense2/w): OnnxMatMul()
  (encoder12/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder12/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder12/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder12/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder12/smolgen/gen_from/reshape): OnnxReshape()
  (encoder12/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder12/smolgen/out/reshape): OnnxReshape()
  (encoder12/smolgen_weights): OnnxBinaryMathOperation()
  (encoder12/mha/QK/softmax): Softmax(dim=3)
  (encoder12/mha/QKV/matmul): OnnxMatMul()
  (encoder12/mha/out/transpose): OnnxTranspose()
  (encoder12/mha/out/reshape): OnnxReshape()
  (encoder12/mha/out/dense/w): OnnxMatMul()
  (encoder12/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder12/alpha*input): OnnxBinaryMathOperation()
  (encoder12/mha/out/skip): OnnxBinaryMathOperation()
  (encoder12/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder12/ffn/dense1/w): OnnxMatMul()
  (encoder12/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder12/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder12/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder12/ffn/dense2/w): OnnxMatMul()
  (encoder12/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder12/alpha*out1): OnnxBinaryMathOperation()
  (encoder12/ffn/skip): OnnxBinaryMathOperation()
  (encoder12/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder13/mha/Q/w): OnnxMatMul()
  (encoder13/mha/Q/b): OnnxBinaryMathOperation()
  (encoder13/mha/Q/reshape): OnnxReshape()
  (encoder13/mha/Q/transpose): OnnxTranspose()
  (encoder13/mha/K/w): OnnxMatMul()
  (encoder13/mha/K/b): OnnxBinaryMathOperation()
  (encoder13/mha/K/reshape): OnnxReshape()
  (encoder13/mha/K/transpose): OnnxTranspose()
  (encoder13/mha/V/w): OnnxMatMul()
  (encoder13/mha/V/b): OnnxBinaryMathOperation()
  (encoder13/mha/V/reshape): OnnxReshape()
  (encoder13/mha/V/transpose): OnnxTranspose()
  (encoder13/mha/QK/matmul): OnnxMatMul()
  (encoder13/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder13/smolgen/compress): OnnxMatMul()
  (encoder13/smolgen/compress/reshape): OnnxReshape()
  (encoder13/smolgen/dense1/w): OnnxMatMul()
  (encoder13/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder13/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder13/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder13/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder13/smolgen/dense2/w): OnnxMatMul()
  (encoder13/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder13/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder13/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder13/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder13/smolgen/gen_from/reshape): OnnxReshape()
  (encoder13/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder13/smolgen/out/reshape): OnnxReshape()
  (encoder13/smolgen_weights): OnnxBinaryMathOperation()
  (encoder13/mha/QK/softmax): Softmax(dim=3)
  (encoder13/mha/QKV/matmul): OnnxMatMul()
  (encoder13/mha/out/transpose): OnnxTranspose()
  (encoder13/mha/out/reshape): OnnxReshape()
  (encoder13/mha/out/dense/w): OnnxMatMul()
  (encoder13/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder13/alpha*input): OnnxBinaryMathOperation()
  (encoder13/mha/out/skip): OnnxBinaryMathOperation()
  (encoder13/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder13/ffn/dense1/w): OnnxMatMul()
  (encoder13/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder13/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder13/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder13/ffn/dense2/w): OnnxMatMul()
  (encoder13/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder13/alpha*out1): OnnxBinaryMathOperation()
  (encoder13/ffn/skip): OnnxBinaryMathOperation()
  (encoder13/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder14/mha/Q/w): OnnxMatMul()
  (encoder14/mha/Q/b): OnnxBinaryMathOperation()
  (encoder14/mha/Q/reshape): OnnxReshape()
  (encoder14/mha/Q/transpose): OnnxTranspose()
  (encoder14/mha/K/w): OnnxMatMul()
  (encoder14/mha/K/b): OnnxBinaryMathOperation()
  (encoder14/mha/K/reshape): OnnxReshape()
  (encoder14/mha/K/transpose): OnnxTranspose()
  (encoder14/mha/V/w): OnnxMatMul()
  (encoder14/mha/V/b): OnnxBinaryMathOperation()
  (encoder14/mha/V/reshape): OnnxReshape()
  (encoder14/mha/V/transpose): OnnxTranspose()
  (encoder14/mha/QK/matmul): OnnxMatMul()
  (encoder14/mha/QK/scale): OnnxBinaryMathOperation()
  (encoder14/smolgen/compress): OnnxMatMul()
  (encoder14/smolgen/compress/reshape): OnnxReshape()
  (encoder14/smolgen/dense1/w): OnnxMatMul()
  (encoder14/smolgen/dense1/b): OnnxBinaryMathOperation()
  (encoder14/smolgen/dense1/swish/sigmoid): Sigmoid()
  (encoder14/smolgen/dense1/swish): OnnxBinaryMathOperation()
  (encoder14/smolgen/ln1): LayerNorm((256,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder14/smolgen/dense2/w): OnnxMatMul()
  (encoder14/smolgen/dense2/b): OnnxBinaryMathOperation()
  (encoder14/smolgen/dense2/swish/sigmoid): Sigmoid()
  (encoder14/smolgen/dense2/swish): OnnxBinaryMathOperation()
  (encoder14/smolgen/ln2): LayerNorm((6144,), eps=0.0010000000474974513, elementwise_affine=True)
  (encoder14/smolgen/gen_from/reshape): OnnxReshape()
  (encoder14/smolgen/smol_weight_gen): OnnxMatMul()
  (encoder14/smolgen/out/reshape): OnnxReshape()
  (encoder14/smolgen_weights): OnnxBinaryMathOperation()
  (encoder14/mha/QK/softmax): Softmax(dim=3)
  (encoder14/mha/QKV/matmul): OnnxMatMul()
  (encoder14/mha/out/transpose): OnnxTranspose()
  (encoder14/mha/out/reshape): OnnxReshape()
  (encoder14/mha/out/dense/w): OnnxMatMul()
  (encoder14/mha/out/dense/b): OnnxBinaryMathOperation()
  (encoder14/alpha*input): OnnxBinaryMathOperation()
  (encoder14/mha/out/skip): OnnxBinaryMathOperation()
  (encoder14/ln1): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (encoder14/ffn/dense1/w): OnnxMatMul()
  (encoder14/ffn/dense1/b): OnnxBinaryMathOperation()
  (encoder14/ffn/dense1/sqrrelu/relu): ReLU()
  (encoder14/ffn/dense1/sqrrelu/sqr): OnnxBinaryMathOperation()
  (encoder14/ffn/dense2/w): OnnxMatMul()
  (encoder14/ffn/dense2/b): OnnxBinaryMathOperation()
  (encoder14/alpha*out1): OnnxBinaryMathOperation()
  (encoder14/ffn/skip): OnnxBinaryMathOperation()
  (encoder14/ln2): LayerNorm((768,), eps=9.999999974752427e-07, elementwise_affine=True)
  (policy/dense1/matmul): OnnxMatMul()
  (policy/dense1/add): OnnxBinaryMathOperation()
  (policy/dense1/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (policy/dense1/mish/tanh): OnnxFunction()
  (policy/dense1/mish): OnnxBinaryMathOperation()
  (policy/Q/matmul): OnnxMatMul()
  (policy/Q/add): OnnxBinaryMathOperation()
  (policy/Q/reshape): OnnxReshape()
  (policy/K/matmul): OnnxMatMul()
  (policy/K/add): OnnxBinaryMathOperation()
  (policy/K/reshape): OnnxReshape()
  (policy/K/transpose): OnnxTranspose()
  (policy/matmul): OnnxMatMul()
  (policy/scale): OnnxBinaryMathOperation()
  (policy/promotion/slice): OnnxSlice()
  (policy/promotion/matmul): OnnxMatMul()
  (policy/promotion/transpose): OnnxTranspose()
  (policy/promotion/split): OnnxSplit13()
  (policy/promotion/add): OnnxBinaryMathOperation()
  (policy/promotion/transpose2): OnnxTranspose()
  (policy/promotion/reshape): OnnxReshape()
  (policy/promotion/slice2): OnnxSlice()
  (policy/promotion/reshape2): OnnxReshape()
  (policy/promotion/concat): OnnxConcat()
  (policy/promotion/reshape3): OnnxReshape()
  (policy/promotion/add2): OnnxBinaryMathOperation()
  (policy/promotion/reshape4): OnnxReshape()
  (policy/concat): OnnxConcat()
  (policy/reshape): OnnxReshape()
  (output/policy): OnnxGather()
  (value/embed/matmul): OnnxMatMul()
  (value/embed/add): OnnxBinaryMathOperation()
  (value/embed/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (value/embed/mish/tanh): OnnxFunction()
  (value/embed/mish): OnnxBinaryMathOperation()
  (value/reshape): OnnxReshape()
  (value/dense1/matmul): OnnxMatMul()
  (value/dense1/add): OnnxBinaryMathOperation()
  (value/dense1/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (value/dense1/mish/tanh): OnnxFunction()
  (value/dense1/mish): OnnxBinaryMathOperation()
  (value/dense2/matmul): OnnxMatMul()
  (value/dense2/add): OnnxBinaryMathOperation()
  (output/wdl): Softmax(dim=1)
  (mlh/embed/matmul): OnnxMatMul()
  (mlh/embed/add): OnnxBinaryMathOperation()
  (mlh/embed/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (mlh/embed/mish/tanh): OnnxFunction()
  (mlh/embed/mish): OnnxBinaryMathOperation()
  (mlh/reshape): OnnxReshape()
  (mlh/dense1/matmul): OnnxMatMul()
  (mlh/dense1/add): OnnxBinaryMathOperation()
  (mlh/dense1/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (mlh/dense1/mish/tanh): OnnxFunction()
  (mlh/dense1/mish): OnnxBinaryMathOperation()
  (mlh/dense2/matmul): OnnxMatMul()
  (mlh/dense2/add): OnnxBinaryMathOperation()
  (mlh/dense2/mish/softplus): Softplus(beta=1.0, threshold=20.0)
  (mlh/dense2/mish/tanh): OnnxFunction()
  (mlh/dense2/mish): OnnxBinaryMathOperation()
  (output/mlh): OnnxCopyIdentity()
)
[7]:
import IPython

from lczerolens.board import LczeroBoard

# Pick a single puzzle row by its index label and split its move list.
puzzle = puzzles.loc[19612]
moves = puzzle.Moves.split()

# Clean position: the puzzle FEN with the first listed move already played
# (presumably the opponent's lead-in move in the puzzle format — confirm).
board = LczeroBoard(puzzle.FEN)
board.push_uci(moves[0])

# Corrupted counterpart, built directly from the dataset's `corrupted_fen`.
corrupted_board = LczeroBoard(puzzle.corrupted_fen)

display(board, corrupted_board)
../../_images/notebooks_tutorials_evidence-of-learned-look-ahead_9_0.svg
../../_images/notebooks_tutorials_evidence-of-learned-look-ahead_9_1.svg
[8]:
# Run the clean and the corrupted position through the model in one call,
# then show the "wdl" entry of the result.
out = model(board, corrupted_board)
out["wdl"]
[8]:
tensor([[0.3586, 0.4711, 0.1703],
        [0.0132, 0.1089, 0.8779]], grad_fn=<SoftmaxBackward0>)

Visualising Attention#

[9]:
layer = 9
head = 5

# Name of the attention-softmax submodule for the chosen encoder layer.
softmax_name = f"encoder{layer}/mha/QK/softmax"

# Trace a forward pass on the clean board and save the softmax output,
# indexed at [0, head], so it is available after the context exits.
with model.trace(board):
    attention = getattr(model, softmax_name).output[0, head].save()

attention.shape
[9]:
torch.Size([64, 64])
[10]:
square = chess.F4

# Take the attention row for the chosen square, detached from the autograd
# graph, and draw it as a heatmap on the board.
square_attention = attention[square].detach()

# `render_heatmap` returns a pair; only the first element (the board SVG) is used.
boardsvg, _ = board.render_heatmap(square_attention)
display(IPython.display.HTML(boardsvg))
. . r . . r . k
. . . . . . p p
p . q . . p . .
. p . . . . . .
. . . R . N . .
P Q . . P . . P
. P . . . P P .
. . . . . . K .

Probing Analysis#

Activation Patching#

[11]:
from lczerolens.lenses import ActivationLens

# Module whose activations are collected for the patching experiment.
MODULE = "encoder13/ln2"
act_lens = ActivationLens(MODULE)

# Analyse the clean board first, then the corrupted one, with the same lens.
clean_acts, corrupted_acts = (
    act_lens.analyse(model, position) for position in (board, corrupted_board)
)
[12]:
# Shape of the activations stored for MODULE on the corrupted position.
output_key = f"{MODULE}_output"
corrupted_acts[output_key].shape
[12]:
torch.Size([64, 768])
[13]:
# Trace a forward pass on the clean board and save the model's output so it
# can be read after the tracing context exits.
# NOTE(review): this rebinds `out`, which the earlier `model(...)` call cell
# already assigned — confirm that overwriting it is intended.
with model.trace(board):
    out = model.output.save()
[ ]: