Framework-Agnostic Interpretability#

Open In Colab

Setup#

[1]:
import importlib.util

DEV = True

if importlib.util.find_spec("google.colab") is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
    %pip install -q lczerolens datasets nnsight
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    %pip install -q ./lczerolens datasets nnsight

Load a Model and Board#

[3]:
from lczerolens import LczeroModel

model = LczeroModel.from_hf("lczerolens/256x19-4508")
model
[3]:
LczeroModel(
    module=GraphModule(
      (inputconv): Conv2d(112, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (inputconv/relu): ReLU()
      (block0/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block0/conv1/relu): ReLU()
      (block0/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block0/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (initializers): Module()
      (block0/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block0/conv2/se/matmul1): OnnxMatMul()
      (block0/conv2/se/add1): OnnxBinaryMathOperation()
      (block0/conv2/se/relu): ReLU()
      (block0/conv2/se/matmul2): OnnxMatMul()
      (block0/conv2/se/add2): OnnxBinaryMathOperation()
      (block0/conv2/se/reshape): OnnxReshape()
      (block0/conv2/se/split): OnnxSplit13()
      (block0/conv2/se/sigmoid): Sigmoid()
      (block0/conv2/se/mul): OnnxBinaryMathOperation()
      (block0/conv2/se/add3): OnnxBinaryMathOperation()
      (block0/conv2/mixin): OnnxBinaryMathOperation()
      (block0/conv2/relu): ReLU()
      (block1/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block1/conv1/relu): ReLU()
      (block1/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block1/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block1/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block1/conv2/se/matmul1): OnnxMatMul()
      (block1/conv2/se/add1): OnnxBinaryMathOperation()
      (block1/conv2/se/relu): ReLU()
      (block1/conv2/se/matmul2): OnnxMatMul()
      (block1/conv2/se/add2): OnnxBinaryMathOperation()
      (block1/conv2/se/reshape): OnnxReshape()
      (block1/conv2/se/split): OnnxSplit13()
      (block1/conv2/se/sigmoid): Sigmoid()
      (block1/conv2/se/mul): OnnxBinaryMathOperation()
      (block1/conv2/se/add3): OnnxBinaryMathOperation()
      (block1/conv2/mixin): OnnxBinaryMathOperation()
      (block1/conv2/relu): ReLU()
      (block2/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block2/conv1/relu): ReLU()
      (block2/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block2/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block2/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block2/conv2/se/matmul1): OnnxMatMul()
      (block2/conv2/se/add1): OnnxBinaryMathOperation()
      (block2/conv2/se/relu): ReLU()
      (block2/conv2/se/matmul2): OnnxMatMul()
      (block2/conv2/se/add2): OnnxBinaryMathOperation()
      (block2/conv2/se/reshape): OnnxReshape()
      (block2/conv2/se/split): OnnxSplit13()
      (block2/conv2/se/sigmoid): Sigmoid()
      (block2/conv2/se/mul): OnnxBinaryMathOperation()
      (block2/conv2/se/add3): OnnxBinaryMathOperation()
      (block2/conv2/mixin): OnnxBinaryMathOperation()
      (block2/conv2/relu): ReLU()
      (block3/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block3/conv1/relu): ReLU()
      (block3/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block3/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block3/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block3/conv2/se/matmul1): OnnxMatMul()
      (block3/conv2/se/add1): OnnxBinaryMathOperation()
      (block3/conv2/se/relu): ReLU()
      (block3/conv2/se/matmul2): OnnxMatMul()
      (block3/conv2/se/add2): OnnxBinaryMathOperation()
      (block3/conv2/se/reshape): OnnxReshape()
      (block3/conv2/se/split): OnnxSplit13()
      (block3/conv2/se/sigmoid): Sigmoid()
      (block3/conv2/se/mul): OnnxBinaryMathOperation()
      (block3/conv2/se/add3): OnnxBinaryMathOperation()
      (block3/conv2/mixin): OnnxBinaryMathOperation()
      (block3/conv2/relu): ReLU()
      (block4/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block4/conv1/relu): ReLU()
      (block4/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block4/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block4/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block4/conv2/se/matmul1): OnnxMatMul()
      (block4/conv2/se/add1): OnnxBinaryMathOperation()
      (block4/conv2/se/relu): ReLU()
      (block4/conv2/se/matmul2): OnnxMatMul()
      (block4/conv2/se/add2): OnnxBinaryMathOperation()
      (block4/conv2/se/reshape): OnnxReshape()
      (block4/conv2/se/split): OnnxSplit13()
      (block4/conv2/se/sigmoid): Sigmoid()
      (block4/conv2/se/mul): OnnxBinaryMathOperation()
      (block4/conv2/se/add3): OnnxBinaryMathOperation()
      (block4/conv2/mixin): OnnxBinaryMathOperation()
      (block4/conv2/relu): ReLU()
      (block5/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block5/conv1/relu): ReLU()
      (block5/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block5/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block5/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block5/conv2/se/matmul1): OnnxMatMul()
      (block5/conv2/se/add1): OnnxBinaryMathOperation()
      (block5/conv2/se/relu): ReLU()
      (block5/conv2/se/matmul2): OnnxMatMul()
      (block5/conv2/se/add2): OnnxBinaryMathOperation()
      (block5/conv2/se/reshape): OnnxReshape()
      (block5/conv2/se/split): OnnxSplit13()
      (block5/conv2/se/sigmoid): Sigmoid()
      (block5/conv2/se/mul): OnnxBinaryMathOperation()
      (block5/conv2/se/add3): OnnxBinaryMathOperation()
      (block5/conv2/mixin): OnnxBinaryMathOperation()
      (block5/conv2/relu): ReLU()
      (block6/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block6/conv1/relu): ReLU()
      (block6/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block6/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block6/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block6/conv2/se/matmul1): OnnxMatMul()
      (block6/conv2/se/add1): OnnxBinaryMathOperation()
      (block6/conv2/se/relu): ReLU()
      (block6/conv2/se/matmul2): OnnxMatMul()
      (block6/conv2/se/add2): OnnxBinaryMathOperation()
      (block6/conv2/se/reshape): OnnxReshape()
      (block6/conv2/se/split): OnnxSplit13()
      (block6/conv2/se/sigmoid): Sigmoid()
      (block6/conv2/se/mul): OnnxBinaryMathOperation()
      (block6/conv2/se/add3): OnnxBinaryMathOperation()
      (block6/conv2/mixin): OnnxBinaryMathOperation()
      (block6/conv2/relu): ReLU()
      (block7/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block7/conv1/relu): ReLU()
      (block7/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block7/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block7/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block7/conv2/se/matmul1): OnnxMatMul()
      (block7/conv2/se/add1): OnnxBinaryMathOperation()
      (block7/conv2/se/relu): ReLU()
      (block7/conv2/se/matmul2): OnnxMatMul()
      (block7/conv2/se/add2): OnnxBinaryMathOperation()
      (block7/conv2/se/reshape): OnnxReshape()
      (block7/conv2/se/split): OnnxSplit13()
      (block7/conv2/se/sigmoid): Sigmoid()
      (block7/conv2/se/mul): OnnxBinaryMathOperation()
      (block7/conv2/se/add3): OnnxBinaryMathOperation()
      (block7/conv2/mixin): OnnxBinaryMathOperation()
      (block7/conv2/relu): ReLU()
      (block8/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block8/conv1/relu): ReLU()
      (block8/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block8/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block8/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block8/conv2/se/matmul1): OnnxMatMul()
      (block8/conv2/se/add1): OnnxBinaryMathOperation()
      (block8/conv2/se/relu): ReLU()
      (block8/conv2/se/matmul2): OnnxMatMul()
      (block8/conv2/se/add2): OnnxBinaryMathOperation()
      (block8/conv2/se/reshape): OnnxReshape()
      (block8/conv2/se/split): OnnxSplit13()
      (block8/conv2/se/sigmoid): Sigmoid()
      (block8/conv2/se/mul): OnnxBinaryMathOperation()
      (block8/conv2/se/add3): OnnxBinaryMathOperation()
      (block8/conv2/mixin): OnnxBinaryMathOperation()
      (block8/conv2/relu): ReLU()
      (block9/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block9/conv1/relu): ReLU()
      (block9/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block9/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block9/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block9/conv2/se/matmul1): OnnxMatMul()
      (block9/conv2/se/add1): OnnxBinaryMathOperation()
      (block9/conv2/se/relu): ReLU()
      (block9/conv2/se/matmul2): OnnxMatMul()
      (block9/conv2/se/add2): OnnxBinaryMathOperation()
      (block9/conv2/se/reshape): OnnxReshape()
      (block9/conv2/se/split): OnnxSplit13()
      (block9/conv2/se/sigmoid): Sigmoid()
      (block9/conv2/se/mul): OnnxBinaryMathOperation()
      (block9/conv2/se/add3): OnnxBinaryMathOperation()
      (block9/conv2/mixin): OnnxBinaryMathOperation()
      (block9/conv2/relu): ReLU()
      (block10/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block10/conv1/relu): ReLU()
      (block10/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block10/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block10/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block10/conv2/se/matmul1): OnnxMatMul()
      (block10/conv2/se/add1): OnnxBinaryMathOperation()
      (block10/conv2/se/relu): ReLU()
      (block10/conv2/se/matmul2): OnnxMatMul()
      (block10/conv2/se/add2): OnnxBinaryMathOperation()
      (block10/conv2/se/reshape): OnnxReshape()
      (block10/conv2/se/split): OnnxSplit13()
      (block10/conv2/se/sigmoid): Sigmoid()
      (block10/conv2/se/mul): OnnxBinaryMathOperation()
      (block10/conv2/se/add3): OnnxBinaryMathOperation()
      (block10/conv2/mixin): OnnxBinaryMathOperation()
      (block10/conv2/relu): ReLU()
      (block11/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block11/conv1/relu): ReLU()
      (block11/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block11/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block11/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block11/conv2/se/matmul1): OnnxMatMul()
      (block11/conv2/se/add1): OnnxBinaryMathOperation()
      (block11/conv2/se/relu): ReLU()
      (block11/conv2/se/matmul2): OnnxMatMul()
      (block11/conv2/se/add2): OnnxBinaryMathOperation()
      (block11/conv2/se/reshape): OnnxReshape()
      (block11/conv2/se/split): OnnxSplit13()
      (block11/conv2/se/sigmoid): Sigmoid()
      (block11/conv2/se/mul): OnnxBinaryMathOperation()
      (block11/conv2/se/add3): OnnxBinaryMathOperation()
      (block11/conv2/mixin): OnnxBinaryMathOperation()
      (block11/conv2/relu): ReLU()
      (block12/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block12/conv1/relu): ReLU()
      (block12/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block12/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block12/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block12/conv2/se/matmul1): OnnxMatMul()
      (block12/conv2/se/add1): OnnxBinaryMathOperation()
      (block12/conv2/se/relu): ReLU()
      (block12/conv2/se/matmul2): OnnxMatMul()
      (block12/conv2/se/add2): OnnxBinaryMathOperation()
      (block12/conv2/se/reshape): OnnxReshape()
      (block12/conv2/se/split): OnnxSplit13()
      (block12/conv2/se/sigmoid): Sigmoid()
      (block12/conv2/se/mul): OnnxBinaryMathOperation()
      (block12/conv2/se/add3): OnnxBinaryMathOperation()
      (block12/conv2/mixin): OnnxBinaryMathOperation()
      (block12/conv2/relu): ReLU()
      (block13/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block13/conv1/relu): ReLU()
      (block13/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block13/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block13/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block13/conv2/se/matmul1): OnnxMatMul()
      (block13/conv2/se/add1): OnnxBinaryMathOperation()
      (block13/conv2/se/relu): ReLU()
      (block13/conv2/se/matmul2): OnnxMatMul()
      (block13/conv2/se/add2): OnnxBinaryMathOperation()
      (block13/conv2/se/reshape): OnnxReshape()
      (block13/conv2/se/split): OnnxSplit13()
      (block13/conv2/se/sigmoid): Sigmoid()
      (block13/conv2/se/mul): OnnxBinaryMathOperation()
      (block13/conv2/se/add3): OnnxBinaryMathOperation()
      (block13/conv2/mixin): OnnxBinaryMathOperation()
      (block13/conv2/relu): ReLU()
      (block14/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block14/conv1/relu): ReLU()
      (block14/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block14/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block14/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block14/conv2/se/matmul1): OnnxMatMul()
      (block14/conv2/se/add1): OnnxBinaryMathOperation()
      (block14/conv2/se/relu): ReLU()
      (block14/conv2/se/matmul2): OnnxMatMul()
      (block14/conv2/se/add2): OnnxBinaryMathOperation()
      (block14/conv2/se/reshape): OnnxReshape()
      (block14/conv2/se/split): OnnxSplit13()
      (block14/conv2/se/sigmoid): Sigmoid()
      (block14/conv2/se/mul): OnnxBinaryMathOperation()
      (block14/conv2/se/add3): OnnxBinaryMathOperation()
      (block14/conv2/mixin): OnnxBinaryMathOperation()
      (block14/conv2/relu): ReLU()
      (block15/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block15/conv1/relu): ReLU()
      (block15/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block15/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block15/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block15/conv2/se/matmul1): OnnxMatMul()
      (block15/conv2/se/add1): OnnxBinaryMathOperation()
      (block15/conv2/se/relu): ReLU()
      (block15/conv2/se/matmul2): OnnxMatMul()
      (block15/conv2/se/add2): OnnxBinaryMathOperation()
      (block15/conv2/se/reshape): OnnxReshape()
      (block15/conv2/se/split): OnnxSplit13()
      (block15/conv2/se/sigmoid): Sigmoid()
      (block15/conv2/se/mul): OnnxBinaryMathOperation()
      (block15/conv2/se/add3): OnnxBinaryMathOperation()
      (block15/conv2/mixin): OnnxBinaryMathOperation()
      (block15/conv2/relu): ReLU()
      (block16/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block16/conv1/relu): ReLU()
      (block16/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block16/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block16/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block16/conv2/se/matmul1): OnnxMatMul()
      (block16/conv2/se/add1): OnnxBinaryMathOperation()
      (block16/conv2/se/relu): ReLU()
      (block16/conv2/se/matmul2): OnnxMatMul()
      (block16/conv2/se/add2): OnnxBinaryMathOperation()
      (block16/conv2/se/reshape): OnnxReshape()
      (block16/conv2/se/split): OnnxSplit13()
      (block16/conv2/se/sigmoid): Sigmoid()
      (block16/conv2/se/mul): OnnxBinaryMathOperation()
      (block16/conv2/se/add3): OnnxBinaryMathOperation()
      (block16/conv2/mixin): OnnxBinaryMathOperation()
      (block16/conv2/relu): ReLU()
      (block17/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block17/conv1/relu): ReLU()
      (block17/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block17/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block17/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block17/conv2/se/matmul1): OnnxMatMul()
      (block17/conv2/se/add1): OnnxBinaryMathOperation()
      (block17/conv2/se/relu): ReLU()
      (block17/conv2/se/matmul2): OnnxMatMul()
      (block17/conv2/se/add2): OnnxBinaryMathOperation()
      (block17/conv2/se/reshape): OnnxReshape()
      (block17/conv2/se/split): OnnxSplit13()
      (block17/conv2/se/sigmoid): Sigmoid()
      (block17/conv2/se/mul): OnnxBinaryMathOperation()
      (block17/conv2/se/add3): OnnxBinaryMathOperation()
      (block17/conv2/mixin): OnnxBinaryMathOperation()
      (block17/conv2/relu): ReLU()
      (block18/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block18/conv1/relu): ReLU()
      (block18/conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (block18/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
      (block18/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
      (block18/conv2/se/matmul1): OnnxMatMul()
      (block18/conv2/se/add1): OnnxBinaryMathOperation()
      (block18/conv2/se/relu): ReLU()
      (block18/conv2/se/matmul2): OnnxMatMul()
      (block18/conv2/se/add2): OnnxBinaryMathOperation()
      (block18/conv2/se/reshape): OnnxReshape()
      (block18/conv2/se/split): OnnxSplit13()
      (block18/conv2/se/sigmoid): Sigmoid()
      (block18/conv2/se/mul): OnnxBinaryMathOperation()
      (block18/conv2/se/add3): OnnxBinaryMathOperation()
      (block18/conv2/mixin): OnnxBinaryMathOperation()
      (block18/conv2/relu): ReLU()
      (policy/conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (policy/conv1/relu): ReLU()
      (policy/conv2): Conv2d(256, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (policy/flatten): OnnxReshape()
      (output/policy): OnnxGather()
      (value/conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
      (value/conv/relu): ReLU()
      (value/reshape): OnnxReshape()
      (value/dense1/matmul): OnnxMatMul()
      (value/dense1/add): OnnxBinaryMathOperation()
      (value/dense1/relu): ReLU()
      (value/dense2/matmul): OnnxMatMul()
      (value/dense2/add): OnnxBinaryMathOperation()
      (output/wdl): Softmax(dim=1)
      (mlh/conv): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))
      (mlh/conv/relu): ReLU()
      (mlh/reshape): OnnxReshape()
      (mlh/dense1/matmul): OnnxMatMul()
      (mlh/dense1/add): OnnxBinaryMathOperation()
      (mlh/dense1/relu): ReLU()
      (mlh/dense2/matmul): OnnxMatMul()
      (mlh/dense2/add): OnnxBinaryMathOperation()
      (mlh/dense2/relu): ReLU()
      (output/mlh): OnnxCopyIdentity()
    )



    def forward(self, input_1):
        inputconv = self.inputconv(input_1);  input_1 = None
        inputconv_relu = getattr(self, "inputconv/relu")(inputconv);  inputconv = None
        block0_conv1 = getattr(self, "block0/conv1")(inputconv_relu)
        block0_conv1_relu = getattr(self, "block0/conv1/relu")(block0_conv1);  block0_conv1 = None
        block0_conv2 = getattr(self, "block0/conv2")(block0_conv1_relu);  block0_conv1_relu = None
        block0_conv2_se_pooled = getattr(self, "block0/conv2/se/pooled")(block0_conv2)
        initializers_onnx_initializer_0 = self.initializers.onnx_initializer_0
        block0_conv2_se_squeeze = getattr(self, "block0/conv2/se/squeeze")(block0_conv2_se_pooled, initializers_onnx_initializer_0);  block0_conv2_se_pooled = initializers_onnx_initializer_0 = None
        initializers_onnx_initializer_1 = self.initializers.onnx_initializer_1
        block0_conv2_se_matmul1 = getattr(self, "block0/conv2/se/matmul1")(block0_conv2_se_squeeze, initializers_onnx_initializer_1);  block0_conv2_se_squeeze = initializers_onnx_initializer_1 = None
        initializers_onnx_initializer_2 = self.initializers.onnx_initializer_2
        block0_conv2_se_add1 = getattr(self, "block0/conv2/se/add1")(block0_conv2_se_matmul1, initializers_onnx_initializer_2);  block0_conv2_se_matmul1 = initializers_onnx_initializer_2 = None
        block0_conv2_se_relu = getattr(self, "block0/conv2/se/relu")(block0_conv2_se_add1);  block0_conv2_se_add1 = None
        initializers_onnx_initializer_3 = self.initializers.onnx_initializer_3
        block0_conv2_se_matmul2 = getattr(self, "block0/conv2/se/matmul2")(block0_conv2_se_relu, initializers_onnx_initializer_3);  block0_conv2_se_relu = initializers_onnx_initializer_3 = None
        initializers_onnx_initializer_4 = self.initializers.onnx_initializer_4
        block0_conv2_se_add2 = getattr(self, "block0/conv2/se/add2")(block0_conv2_se_matmul2, initializers_onnx_initializer_4);  block0_conv2_se_matmul2 = initializers_onnx_initializer_4 = None
        initializers_onnx_initializer_5 = self.initializers.onnx_initializer_5
        block0_conv2_se_reshape = getattr(self, "block0/conv2/se/reshape")(block0_conv2_se_add2, initializers_onnx_initializer_5);  block0_conv2_se_add2 = initializers_onnx_initializer_5 = None
        block0_conv2_se_split = getattr(self, "block0/conv2/se/split")(block0_conv2_se_reshape);  block0_conv2_se_reshape = None
        getitem = block0_conv2_se_split[0]
        block0_conv2_se_sigmoid = getattr(self, "block0/conv2/se/sigmoid")(getitem);  getitem = None
        block0_conv2_se_mul = getattr(self, "block0/conv2/se/mul")(block0_conv2_se_sigmoid, block0_conv2);  block0_conv2_se_sigmoid = block0_conv2 = None
        getitem_1 = block0_conv2_se_split[1];  block0_conv2_se_split = None
        block0_conv2_se_add3 = getattr(self, "block0/conv2/se/add3")(block0_conv2_se_mul, getitem_1);  block0_conv2_se_mul = getitem_1 = None
        block0_conv2_mixin = getattr(self, "block0/conv2/mixin")(block0_conv2_se_add3, inputconv_relu);  block0_conv2_se_add3 = inputconv_relu = None
        block0_conv2_relu = getattr(self, "block0/conv2/relu")(block0_conv2_mixin);  block0_conv2_mixin = None
        block1_conv1 = getattr(self, "block1/conv1")(block0_conv2_relu)
        block1_conv1_relu = getattr(self, "block1/conv1/relu")(block1_conv1);  block1_conv1 = None
        block1_conv2 = getattr(self, "block1/conv2")(block1_conv1_relu);  block1_conv1_relu = None
        block1_conv2_se_pooled = getattr(self, "block1/conv2/se/pooled")(block1_conv2)
        initializers_onnx_initializer_6 = self.initializers.onnx_initializer_6
        block1_conv2_se_squeeze = getattr(self, "block1/conv2/se/squeeze")(block1_conv2_se_pooled, initializers_onnx_initializer_6);  block1_conv2_se_pooled = initializers_onnx_initializer_6 = None
        initializers_onnx_initializer_7 = self.initializers.onnx_initializer_7
        block1_conv2_se_matmul1 = getattr(self, "block1/conv2/se/matmul1")(block1_conv2_se_squeeze, initializers_onnx_initializer_7);  block1_conv2_se_squeeze = initializers_onnx_initializer_7 = None
        initializers_onnx_initializer_8 = self.initializers.onnx_initializer_8
        block1_conv2_se_add1 = getattr(self, "block1/conv2/se/add1")(block1_conv2_se_matmul1, initializers_onnx_initializer_8);  block1_conv2_se_matmul1 = initializers_onnx_initializer_8 = None
        block1_conv2_se_relu = getattr(self, "block1/conv2/se/relu")(block1_conv2_se_add1);  block1_conv2_se_add1 = None
        initializers_onnx_initializer_9 = self.initializers.onnx_initializer_9
        block1_conv2_se_matmul2 = getattr(self, "block1/conv2/se/matmul2")(block1_conv2_se_relu, initializers_onnx_initializer_9);  block1_conv2_se_relu = initializers_onnx_initializer_9 = None
        initializers_onnx_initializer_10 = self.initializers.onnx_initializer_10
        block1_conv2_se_add2 = getattr(self, "block1/conv2/se/add2")(block1_conv2_se_matmul2, initializers_onnx_initializer_10);  block1_conv2_se_matmul2 = initializers_onnx_initializer_10 = None
        initializers_onnx_initializer_11 = self.initializers.onnx_initializer_11
        block1_conv2_se_reshape = getattr(self, "block1/conv2/se/reshape")(block1_conv2_se_add2, initializers_onnx_initializer_11);  block1_conv2_se_add2 = initializers_onnx_initializer_11 = None
        block1_conv2_se_split = getattr(self, "block1/conv2/se/split")(block1_conv2_se_reshape);  block1_conv2_se_reshape = None
        getitem_2 = block1_conv2_se_split[0]
        block1_conv2_se_sigmoid = getattr(self, "block1/conv2/se/sigmoid")(getitem_2);  getitem_2 = None
        block1_conv2_se_mul = getattr(self, "block1/conv2/se/mul")(block1_conv2_se_sigmoid, block1_conv2);  block1_conv2_se_sigmoid = block1_conv2 = None
        getitem_3 = block1_conv2_se_split[1];  block1_conv2_se_split = None
        block1_conv2_se_add3 = getattr(self, "block1/conv2/se/add3")(block1_conv2_se_mul, getitem_3);  block1_conv2_se_mul = getitem_3 = None
        block1_conv2_mixin = getattr(self, "block1/conv2/mixin")(block1_conv2_se_add3, block0_conv2_relu);  block1_conv2_se_add3 = block0_conv2_relu = None
        block1_conv2_relu = getattr(self, "block1/conv2/relu")(block1_conv2_mixin);  block1_conv2_mixin = None
        block2_conv1 = getattr(self, "block2/conv1")(block1_conv2_relu)
        block2_conv1_relu = getattr(self, "block2/conv1/relu")(block2_conv1);  block2_conv1 = None
        block2_conv2 = getattr(self, "block2/conv2")(block2_conv1_relu);  block2_conv1_relu = None
        block2_conv2_se_pooled = getattr(self, "block2/conv2/se/pooled")(block2_conv2)
        initializers_onnx_initializer_12 = self.initializers.onnx_initializer_12
        block2_conv2_se_squeeze = getattr(self, "block2/conv2/se/squeeze")(block2_conv2_se_pooled, initializers_onnx_initializer_12);  block2_conv2_se_pooled = initializers_onnx_initializer_12 = None
        initializers_onnx_initializer_13 = self.initializers.onnx_initializer_13
        block2_conv2_se_matmul1 = getattr(self, "block2/conv2/se/matmul1")(block2_conv2_se_squeeze, initializers_onnx_initializer_13);  block2_conv2_se_squeeze = initializers_onnx_initializer_13 = None
        initializers_onnx_initializer_14 = self.initializers.onnx_initializer_14
        block2_conv2_se_add1 = getattr(self, "block2/conv2/se/add1")(block2_conv2_se_matmul1, initializers_onnx_initializer_14);  block2_conv2_se_matmul1 = initializers_onnx_initializer_14 = None
        block2_conv2_se_relu = getattr(self, "block2/conv2/se/relu")(block2_conv2_se_add1);  block2_conv2_se_add1 = None
        initializers_onnx_initializer_15 = self.initializers.onnx_initializer_15
        block2_conv2_se_matmul2 = getattr(self, "block2/conv2/se/matmul2")(block2_conv2_se_relu, initializers_onnx_initializer_15);  block2_conv2_se_relu = initializers_onnx_initializer_15 = None
        initializers_onnx_initializer_16 = self.initializers.onnx_initializer_16
        block2_conv2_se_add2 = getattr(self, "block2/conv2/se/add2")(block2_conv2_se_matmul2, initializers_onnx_initializer_16);  block2_conv2_se_matmul2 = initializers_onnx_initializer_16 = None
        initializers_onnx_initializer_17 = self.initializers.onnx_initializer_17
        block2_conv2_se_reshape = getattr(self, "block2/conv2/se/reshape")(block2_conv2_se_add2, initializers_onnx_initializer_17);  block2_conv2_se_add2 = initializers_onnx_initializer_17 = None
        block2_conv2_se_split = getattr(self, "block2/conv2/se/split")(block2_conv2_se_reshape);  block2_conv2_se_reshape = None
        getitem_4 = block2_conv2_se_split[0]
        block2_conv2_se_sigmoid = getattr(self, "block2/conv2/se/sigmoid")(getitem_4);  getitem_4 = None
        block2_conv2_se_mul = getattr(self, "block2/conv2/se/mul")(block2_conv2_se_sigmoid, block2_conv2);  block2_conv2_se_sigmoid = block2_conv2 = None
        getitem_5 = block2_conv2_se_split[1];  block2_conv2_se_split = None
        block2_conv2_se_add3 = getattr(self, "block2/conv2/se/add3")(block2_conv2_se_mul, getitem_5);  block2_conv2_se_mul = getitem_5 = None
        block2_conv2_mixin = getattr(self, "block2/conv2/mixin")(block2_conv2_se_add3, block1_conv2_relu);  block2_conv2_se_add3 = block1_conv2_relu = None
        block2_conv2_relu = getattr(self, "block2/conv2/relu")(block2_conv2_mixin);  block2_conv2_mixin = None
        block3_conv1 = getattr(self, "block3/conv1")(block2_conv2_relu)
        block3_conv1_relu = getattr(self, "block3/conv1/relu")(block3_conv1);  block3_conv1 = None
        block3_conv2 = getattr(self, "block3/conv2")(block3_conv1_relu);  block3_conv1_relu = None
        block3_conv2_se_pooled = getattr(self, "block3/conv2/se/pooled")(block3_conv2)
        initializers_onnx_initializer_18 = self.initializers.onnx_initializer_18
        block3_conv2_se_squeeze = getattr(self, "block3/conv2/se/squeeze")(block3_conv2_se_pooled, initializers_onnx_initializer_18);  block3_conv2_se_pooled = initializers_onnx_initializer_18 = None
        initializers_onnx_initializer_19 = self.initializers.onnx_initializer_19
        block3_conv2_se_matmul1 = getattr(self, "block3/conv2/se/matmul1")(block3_conv2_se_squeeze, initializers_onnx_initializer_19);  block3_conv2_se_squeeze = initializers_onnx_initializer_19 = None
        initializers_onnx_initializer_20 = self.initializers.onnx_initializer_20
        block3_conv2_se_add1 = getattr(self, "block3/conv2/se/add1")(block3_conv2_se_matmul1, initializers_onnx_initializer_20);  block3_conv2_se_matmul1 = initializers_onnx_initializer_20 = None
        block3_conv2_se_relu = getattr(self, "block3/conv2/se/relu")(block3_conv2_se_add1);  block3_conv2_se_add1 = None
        initializers_onnx_initializer_21 = self.initializers.onnx_initializer_21
        block3_conv2_se_matmul2 = getattr(self, "block3/conv2/se/matmul2")(block3_conv2_se_relu, initializers_onnx_initializer_21);  block3_conv2_se_relu = initializers_onnx_initializer_21 = None
        initializers_onnx_initializer_22 = self.initializers.onnx_initializer_22
        block3_conv2_se_add2 = getattr(self, "block3/conv2/se/add2")(block3_conv2_se_matmul2, initializers_onnx_initializer_22);  block3_conv2_se_matmul2 = initializers_onnx_initializer_22 = None
        initializers_onnx_initializer_23 = self.initializers.onnx_initializer_23
        block3_conv2_se_reshape = getattr(self, "block3/conv2/se/reshape")(block3_conv2_se_add2, initializers_onnx_initializer_23);  block3_conv2_se_add2 = initializers_onnx_initializer_23 = None
        block3_conv2_se_split = getattr(self, "block3/conv2/se/split")(block3_conv2_se_reshape);  block3_conv2_se_reshape = None
        getitem_6 = block3_conv2_se_split[0]
        block3_conv2_se_sigmoid = getattr(self, "block3/conv2/se/sigmoid")(getitem_6);  getitem_6 = None
        block3_conv2_se_mul = getattr(self, "block3/conv2/se/mul")(block3_conv2_se_sigmoid, block3_conv2);  block3_conv2_se_sigmoid = block3_conv2 = None
        getitem_7 = block3_conv2_se_split[1];  block3_conv2_se_split = None
        block3_conv2_se_add3 = getattr(self, "block3/conv2/se/add3")(block3_conv2_se_mul, getitem_7);  block3_conv2_se_mul = getitem_7 = None
        block3_conv2_mixin = getattr(self, "block3/conv2/mixin")(block3_conv2_se_add3, block2_conv2_relu);  block3_conv2_se_add3 = block2_conv2_relu = None
        block3_conv2_relu = getattr(self, "block3/conv2/relu")(block3_conv2_mixin);  block3_conv2_mixin = None
        block4_conv1 = getattr(self, "block4/conv1")(block3_conv2_relu)
        block4_conv1_relu = getattr(self, "block4/conv1/relu")(block4_conv1);  block4_conv1 = None
        block4_conv2 = getattr(self, "block4/conv2")(block4_conv1_relu);  block4_conv1_relu = None
        block4_conv2_se_pooled = getattr(self, "block4/conv2/se/pooled")(block4_conv2)
        initializers_onnx_initializer_24 = self.initializers.onnx_initializer_24
        block4_conv2_se_squeeze = getattr(self, "block4/conv2/se/squeeze")(block4_conv2_se_pooled, initializers_onnx_initializer_24);  block4_conv2_se_pooled = initializers_onnx_initializer_24 = None
        initializers_onnx_initializer_25 = self.initializers.onnx_initializer_25
        block4_conv2_se_matmul1 = getattr(self, "block4/conv2/se/matmul1")(block4_conv2_se_squeeze, initializers_onnx_initializer_25);  block4_conv2_se_squeeze = initializers_onnx_initializer_25 = None
        initializers_onnx_initializer_26 = self.initializers.onnx_initializer_26
        block4_conv2_se_add1 = getattr(self, "block4/conv2/se/add1")(block4_conv2_se_matmul1, initializers_onnx_initializer_26);  block4_conv2_se_matmul1 = initializers_onnx_initializer_26 = None
        block4_conv2_se_relu = getattr(self, "block4/conv2/se/relu")(block4_conv2_se_add1);  block4_conv2_se_add1 = None
        initializers_onnx_initializer_27 = self.initializers.onnx_initializer_27
        block4_conv2_se_matmul2 = getattr(self, "block4/conv2/se/matmul2")(block4_conv2_se_relu, initializers_onnx_initializer_27);  block4_conv2_se_relu = initializers_onnx_initializer_27 = None
        initializers_onnx_initializer_28 = self.initializers.onnx_initializer_28
        block4_conv2_se_add2 = getattr(self, "block4/conv2/se/add2")(block4_conv2_se_matmul2, initializers_onnx_initializer_28);  block4_conv2_se_matmul2 = initializers_onnx_initializer_28 = None
        initializers_onnx_initializer_29 = self.initializers.onnx_initializer_29
        block4_conv2_se_reshape = getattr(self, "block4/conv2/se/reshape")(block4_conv2_se_add2, initializers_onnx_initializer_29);  block4_conv2_se_add2 = initializers_onnx_initializer_29 = None
        block4_conv2_se_split = getattr(self, "block4/conv2/se/split")(block4_conv2_se_reshape);  block4_conv2_se_reshape = None
        getitem_8 = block4_conv2_se_split[0]
        block4_conv2_se_sigmoid = getattr(self, "block4/conv2/se/sigmoid")(getitem_8);  getitem_8 = None
        block4_conv2_se_mul = getattr(self, "block4/conv2/se/mul")(block4_conv2_se_sigmoid, block4_conv2);  block4_conv2_se_sigmoid = block4_conv2 = None
        getitem_9 = block4_conv2_se_split[1];  block4_conv2_se_split = None
        block4_conv2_se_add3 = getattr(self, "block4/conv2/se/add3")(block4_conv2_se_mul, getitem_9);  block4_conv2_se_mul = getitem_9 = None
        block4_conv2_mixin = getattr(self, "block4/conv2/mixin")(block4_conv2_se_add3, block3_conv2_relu);  block4_conv2_se_add3 = block3_conv2_relu = None
        block4_conv2_relu = getattr(self, "block4/conv2/relu")(block4_conv2_mixin);  block4_conv2_mixin = None
        block5_conv1 = getattr(self, "block5/conv1")(block4_conv2_relu)
        block5_conv1_relu = getattr(self, "block5/conv1/relu")(block5_conv1);  block5_conv1 = None
        block5_conv2 = getattr(self, "block5/conv2")(block5_conv1_relu);  block5_conv1_relu = None
        block5_conv2_se_pooled = getattr(self, "block5/conv2/se/pooled")(block5_conv2)
        initializers_onnx_initializer_30 = self.initializers.onnx_initializer_30
        block5_conv2_se_squeeze = getattr(self, "block5/conv2/se/squeeze")(block5_conv2_se_pooled, initializers_onnx_initializer_30);  block5_conv2_se_pooled = initializers_onnx_initializer_30 = None
        initializers_onnx_initializer_31 = self.initializers.onnx_initializer_31
        block5_conv2_se_matmul1 = getattr(self, "block5/conv2/se/matmul1")(block5_conv2_se_squeeze, initializers_onnx_initializer_31);  block5_conv2_se_squeeze = initializers_onnx_initializer_31 = None
        initializers_onnx_initializer_32 = self.initializers.onnx_initializer_32
        block5_conv2_se_add1 = getattr(self, "block5/conv2/se/add1")(block5_conv2_se_matmul1, initializers_onnx_initializer_32);  block5_conv2_se_matmul1 = initializers_onnx_initializer_32 = None
        block5_conv2_se_relu = getattr(self, "block5/conv2/se/relu")(block5_conv2_se_add1);  block5_conv2_se_add1 = None
        initializers_onnx_initializer_33 = self.initializers.onnx_initializer_33
        block5_conv2_se_matmul2 = getattr(self, "block5/conv2/se/matmul2")(block5_conv2_se_relu, initializers_onnx_initializer_33);  block5_conv2_se_relu = initializers_onnx_initializer_33 = None
        initializers_onnx_initializer_34 = self.initializers.onnx_initializer_34
        block5_conv2_se_add2 = getattr(self, "block5/conv2/se/add2")(block5_conv2_se_matmul2, initializers_onnx_initializer_34);  block5_conv2_se_matmul2 = initializers_onnx_initializer_34 = None
        initializers_onnx_initializer_35 = self.initializers.onnx_initializer_35
        block5_conv2_se_reshape = getattr(self, "block5/conv2/se/reshape")(block5_conv2_se_add2, initializers_onnx_initializer_35);  block5_conv2_se_add2 = initializers_onnx_initializer_35 = None
        block5_conv2_se_split = getattr(self, "block5/conv2/se/split")(block5_conv2_se_reshape);  block5_conv2_se_reshape = None
        getitem_10 = block5_conv2_se_split[0]
        block5_conv2_se_sigmoid = getattr(self, "block5/conv2/se/sigmoid")(getitem_10);  getitem_10 = None
        block5_conv2_se_mul = getattr(self, "block5/conv2/se/mul")(block5_conv2_se_sigmoid, block5_conv2);  block5_conv2_se_sigmoid = block5_conv2 = None
        getitem_11 = block5_conv2_se_split[1];  block5_conv2_se_split = None
        block5_conv2_se_add3 = getattr(self, "block5/conv2/se/add3")(block5_conv2_se_mul, getitem_11);  block5_conv2_se_mul = getitem_11 = None
        block5_conv2_mixin = getattr(self, "block5/conv2/mixin")(block5_conv2_se_add3, block4_conv2_relu);  block5_conv2_se_add3 = block4_conv2_relu = None
        block5_conv2_relu = getattr(self, "block5/conv2/relu")(block5_conv2_mixin);  block5_conv2_mixin = None
        block6_conv1 = getattr(self, "block6/conv1")(block5_conv2_relu)
        block6_conv1_relu = getattr(self, "block6/conv1/relu")(block6_conv1);  block6_conv1 = None
        block6_conv2 = getattr(self, "block6/conv2")(block6_conv1_relu);  block6_conv1_relu = None
        block6_conv2_se_pooled = getattr(self, "block6/conv2/se/pooled")(block6_conv2)
        initializers_onnx_initializer_36 = self.initializers.onnx_initializer_36
        block6_conv2_se_squeeze = getattr(self, "block6/conv2/se/squeeze")(block6_conv2_se_pooled, initializers_onnx_initializer_36);  block6_conv2_se_pooled = initializers_onnx_initializer_36 = None
        initializers_onnx_initializer_37 = self.initializers.onnx_initializer_37
        block6_conv2_se_matmul1 = getattr(self, "block6/conv2/se/matmul1")(block6_conv2_se_squeeze, initializers_onnx_initializer_37);  block6_conv2_se_squeeze = initializers_onnx_initializer_37 = None
        initializers_onnx_initializer_38 = self.initializers.onnx_initializer_38
        block6_conv2_se_add1 = getattr(self, "block6/conv2/se/add1")(block6_conv2_se_matmul1, initializers_onnx_initializer_38);  block6_conv2_se_matmul1 = initializers_onnx_initializer_38 = None
        block6_conv2_se_relu = getattr(self, "block6/conv2/se/relu")(block6_conv2_se_add1);  block6_conv2_se_add1 = None
        initializers_onnx_initializer_39 = self.initializers.onnx_initializer_39
        block6_conv2_se_matmul2 = getattr(self, "block6/conv2/se/matmul2")(block6_conv2_se_relu, initializers_onnx_initializer_39);  block6_conv2_se_relu = initializers_onnx_initializer_39 = None
        initializers_onnx_initializer_40 = self.initializers.onnx_initializer_40
        block6_conv2_se_add2 = getattr(self, "block6/conv2/se/add2")(block6_conv2_se_matmul2, initializers_onnx_initializer_40);  block6_conv2_se_matmul2 = initializers_onnx_initializer_40 = None
        initializers_onnx_initializer_41 = self.initializers.onnx_initializer_41
        block6_conv2_se_reshape = getattr(self, "block6/conv2/se/reshape")(block6_conv2_se_add2, initializers_onnx_initializer_41);  block6_conv2_se_add2 = initializers_onnx_initializer_41 = None
        block6_conv2_se_split = getattr(self, "block6/conv2/se/split")(block6_conv2_se_reshape);  block6_conv2_se_reshape = None
        getitem_12 = block6_conv2_se_split[0]
        block6_conv2_se_sigmoid = getattr(self, "block6/conv2/se/sigmoid")(getitem_12);  getitem_12 = None
        block6_conv2_se_mul = getattr(self, "block6/conv2/se/mul")(block6_conv2_se_sigmoid, block6_conv2);  block6_conv2_se_sigmoid = block6_conv2 = None
        getitem_13 = block6_conv2_se_split[1];  block6_conv2_se_split = None
        block6_conv2_se_add3 = getattr(self, "block6/conv2/se/add3")(block6_conv2_se_mul, getitem_13);  block6_conv2_se_mul = getitem_13 = None
        block6_conv2_mixin = getattr(self, "block6/conv2/mixin")(block6_conv2_se_add3, block5_conv2_relu);  block6_conv2_se_add3 = block5_conv2_relu = None
        block6_conv2_relu = getattr(self, "block6/conv2/relu")(block6_conv2_mixin);  block6_conv2_mixin = None
        block7_conv1 = getattr(self, "block7/conv1")(block6_conv2_relu)
        block7_conv1_relu = getattr(self, "block7/conv1/relu")(block7_conv1);  block7_conv1 = None
        block7_conv2 = getattr(self, "block7/conv2")(block7_conv1_relu);  block7_conv1_relu = None
        block7_conv2_se_pooled = getattr(self, "block7/conv2/se/pooled")(block7_conv2)
        initializers_onnx_initializer_42 = self.initializers.onnx_initializer_42
        block7_conv2_se_squeeze = getattr(self, "block7/conv2/se/squeeze")(block7_conv2_se_pooled, initializers_onnx_initializer_42);  block7_conv2_se_pooled = initializers_onnx_initializer_42 = None
        initializers_onnx_initializer_43 = self.initializers.onnx_initializer_43
        block7_conv2_se_matmul1 = getattr(self, "block7/conv2/se/matmul1")(block7_conv2_se_squeeze, initializers_onnx_initializer_43);  block7_conv2_se_squeeze = initializers_onnx_initializer_43 = None
        initializers_onnx_initializer_44 = self.initializers.onnx_initializer_44
        block7_conv2_se_add1 = getattr(self, "block7/conv2/se/add1")(block7_conv2_se_matmul1, initializers_onnx_initializer_44);  block7_conv2_se_matmul1 = initializers_onnx_initializer_44 = None
        block7_conv2_se_relu = getattr(self, "block7/conv2/se/relu")(block7_conv2_se_add1);  block7_conv2_se_add1 = None
        initializers_onnx_initializer_45 = self.initializers.onnx_initializer_45
        block7_conv2_se_matmul2 = getattr(self, "block7/conv2/se/matmul2")(block7_conv2_se_relu, initializers_onnx_initializer_45);  block7_conv2_se_relu = initializers_onnx_initializer_45 = None
        initializers_onnx_initializer_46 = self.initializers.onnx_initializer_46
        block7_conv2_se_add2 = getattr(self, "block7/conv2/se/add2")(block7_conv2_se_matmul2, initializers_onnx_initializer_46);  block7_conv2_se_matmul2 = initializers_onnx_initializer_46 = None
        initializers_onnx_initializer_47 = self.initializers.onnx_initializer_47
        block7_conv2_se_reshape = getattr(self, "block7/conv2/se/reshape")(block7_conv2_se_add2, initializers_onnx_initializer_47);  block7_conv2_se_add2 = initializers_onnx_initializer_47 = None
        block7_conv2_se_split = getattr(self, "block7/conv2/se/split")(block7_conv2_se_reshape);  block7_conv2_se_reshape = None
        getitem_14 = block7_conv2_se_split[0]
        block7_conv2_se_sigmoid = getattr(self, "block7/conv2/se/sigmoid")(getitem_14);  getitem_14 = None
        block7_conv2_se_mul = getattr(self, "block7/conv2/se/mul")(block7_conv2_se_sigmoid, block7_conv2);  block7_conv2_se_sigmoid = block7_conv2 = None
        getitem_15 = block7_conv2_se_split[1];  block7_conv2_se_split = None
        block7_conv2_se_add3 = getattr(self, "block7/conv2/se/add3")(block7_conv2_se_mul, getitem_15);  block7_conv2_se_mul = getitem_15 = None
        block7_conv2_mixin = getattr(self, "block7/conv2/mixin")(block7_conv2_se_add3, block6_conv2_relu);  block7_conv2_se_add3 = block6_conv2_relu = None
        block7_conv2_relu = getattr(self, "block7/conv2/relu")(block7_conv2_mixin);  block7_conv2_mixin = None
        block8_conv1 = getattr(self, "block8/conv1")(block7_conv2_relu)
        block8_conv1_relu = getattr(self, "block8/conv1/relu")(block8_conv1);  block8_conv1 = None
        block8_conv2 = getattr(self, "block8/conv2")(block8_conv1_relu);  block8_conv1_relu = None
        block8_conv2_se_pooled = getattr(self, "block8/conv2/se/pooled")(block8_conv2)
        initializers_onnx_initializer_48 = self.initializers.onnx_initializer_48
        block8_conv2_se_squeeze = getattr(self, "block8/conv2/se/squeeze")(block8_conv2_se_pooled, initializers_onnx_initializer_48);  block8_conv2_se_pooled = initializers_onnx_initializer_48 = None
        initializers_onnx_initializer_49 = self.initializers.onnx_initializer_49
        block8_conv2_se_matmul1 = getattr(self, "block8/conv2/se/matmul1")(block8_conv2_se_squeeze, initializers_onnx_initializer_49);  block8_conv2_se_squeeze = initializers_onnx_initializer_49 = None
        initializers_onnx_initializer_50 = self.initializers.onnx_initializer_50
        block8_conv2_se_add1 = getattr(self, "block8/conv2/se/add1")(block8_conv2_se_matmul1, initializers_onnx_initializer_50);  block8_conv2_se_matmul1 = initializers_onnx_initializer_50 = None
        block8_conv2_se_relu = getattr(self, "block8/conv2/se/relu")(block8_conv2_se_add1);  block8_conv2_se_add1 = None
        initializers_onnx_initializer_51 = self.initializers.onnx_initializer_51
        block8_conv2_se_matmul2 = getattr(self, "block8/conv2/se/matmul2")(block8_conv2_se_relu, initializers_onnx_initializer_51);  block8_conv2_se_relu = initializers_onnx_initializer_51 = None
        initializers_onnx_initializer_52 = self.initializers.onnx_initializer_52
        block8_conv2_se_add2 = getattr(self, "block8/conv2/se/add2")(block8_conv2_se_matmul2, initializers_onnx_initializer_52);  block8_conv2_se_matmul2 = initializers_onnx_initializer_52 = None
        initializers_onnx_initializer_53 = self.initializers.onnx_initializer_53
        block8_conv2_se_reshape = getattr(self, "block8/conv2/se/reshape")(block8_conv2_se_add2, initializers_onnx_initializer_53);  block8_conv2_se_add2 = initializers_onnx_initializer_53 = None
        block8_conv2_se_split = getattr(self, "block8/conv2/se/split")(block8_conv2_se_reshape);  block8_conv2_se_reshape = None
        getitem_16 = block8_conv2_se_split[0]
        block8_conv2_se_sigmoid = getattr(self, "block8/conv2/se/sigmoid")(getitem_16);  getitem_16 = None
        block8_conv2_se_mul = getattr(self, "block8/conv2/se/mul")(block8_conv2_se_sigmoid, block8_conv2);  block8_conv2_se_sigmoid = block8_conv2 = None
        getitem_17 = block8_conv2_se_split[1];  block8_conv2_se_split = None
        block8_conv2_se_add3 = getattr(self, "block8/conv2/se/add3")(block8_conv2_se_mul, getitem_17);  block8_conv2_se_mul = getitem_17 = None
        block8_conv2_mixin = getattr(self, "block8/conv2/mixin")(block8_conv2_se_add3, block7_conv2_relu);  block8_conv2_se_add3 = block7_conv2_relu = None
        block8_conv2_relu = getattr(self, "block8/conv2/relu")(block8_conv2_mixin);  block8_conv2_mixin = None
        block9_conv1 = getattr(self, "block9/conv1")(block8_conv2_relu)
        block9_conv1_relu = getattr(self, "block9/conv1/relu")(block9_conv1);  block9_conv1 = None
        block9_conv2 = getattr(self, "block9/conv2")(block9_conv1_relu);  block9_conv1_relu = None
        block9_conv2_se_pooled = getattr(self, "block9/conv2/se/pooled")(block9_conv2)
        initializers_onnx_initializer_54 = self.initializers.onnx_initializer_54
        block9_conv2_se_squeeze = getattr(self, "block9/conv2/se/squeeze")(block9_conv2_se_pooled, initializers_onnx_initializer_54);  block9_conv2_se_pooled = initializers_onnx_initializer_54 = None
        initializers_onnx_initializer_55 = self.initializers.onnx_initializer_55
        block9_conv2_se_matmul1 = getattr(self, "block9/conv2/se/matmul1")(block9_conv2_se_squeeze, initializers_onnx_initializer_55);  block9_conv2_se_squeeze = initializers_onnx_initializer_55 = None
        initializers_onnx_initializer_56 = self.initializers.onnx_initializer_56
        block9_conv2_se_add1 = getattr(self, "block9/conv2/se/add1")(block9_conv2_se_matmul1, initializers_onnx_initializer_56);  block9_conv2_se_matmul1 = initializers_onnx_initializer_56 = None
        block9_conv2_se_relu = getattr(self, "block9/conv2/se/relu")(block9_conv2_se_add1);  block9_conv2_se_add1 = None
        initializers_onnx_initializer_57 = self.initializers.onnx_initializer_57
        block9_conv2_se_matmul2 = getattr(self, "block9/conv2/se/matmul2")(block9_conv2_se_relu, initializers_onnx_initializer_57);  block9_conv2_se_relu = initializers_onnx_initializer_57 = None
        initializers_onnx_initializer_58 = self.initializers.onnx_initializer_58
        block9_conv2_se_add2 = getattr(self, "block9/conv2/se/add2")(block9_conv2_se_matmul2, initializers_onnx_initializer_58);  block9_conv2_se_matmul2 = initializers_onnx_initializer_58 = None
        initializers_onnx_initializer_59 = self.initializers.onnx_initializer_59
        block9_conv2_se_reshape = getattr(self, "block9/conv2/se/reshape")(block9_conv2_se_add2, initializers_onnx_initializer_59);  block9_conv2_se_add2 = initializers_onnx_initializer_59 = None
        block9_conv2_se_split = getattr(self, "block9/conv2/se/split")(block9_conv2_se_reshape);  block9_conv2_se_reshape = None
        getitem_18 = block9_conv2_se_split[0]
        block9_conv2_se_sigmoid = getattr(self, "block9/conv2/se/sigmoid")(getitem_18);  getitem_18 = None
        block9_conv2_se_mul = getattr(self, "block9/conv2/se/mul")(block9_conv2_se_sigmoid, block9_conv2);  block9_conv2_se_sigmoid = block9_conv2 = None
        getitem_19 = block9_conv2_se_split[1];  block9_conv2_se_split = None
        block9_conv2_se_add3 = getattr(self, "block9/conv2/se/add3")(block9_conv2_se_mul, getitem_19);  block9_conv2_se_mul = getitem_19 = None
        block9_conv2_mixin = getattr(self, "block9/conv2/mixin")(block9_conv2_se_add3, block8_conv2_relu);  block9_conv2_se_add3 = block8_conv2_relu = None
        block9_conv2_relu = getattr(self, "block9/conv2/relu")(block9_conv2_mixin);  block9_conv2_mixin = None
        block10_conv1 = getattr(self, "block10/conv1")(block9_conv2_relu)
        block10_conv1_relu = getattr(self, "block10/conv1/relu")(block10_conv1);  block10_conv1 = None
        block10_conv2 = getattr(self, "block10/conv2")(block10_conv1_relu);  block10_conv1_relu = None
        block10_conv2_se_pooled = getattr(self, "block10/conv2/se/pooled")(block10_conv2)
        initializers_onnx_initializer_60 = self.initializers.onnx_initializer_60
        block10_conv2_se_squeeze = getattr(self, "block10/conv2/se/squeeze")(block10_conv2_se_pooled, initializers_onnx_initializer_60);  block10_conv2_se_pooled = initializers_onnx_initializer_60 = None
        initializers_onnx_initializer_61 = self.initializers.onnx_initializer_61
        block10_conv2_se_matmul1 = getattr(self, "block10/conv2/se/matmul1")(block10_conv2_se_squeeze, initializers_onnx_initializer_61);  block10_conv2_se_squeeze = initializers_onnx_initializer_61 = None
        initializers_onnx_initializer_62 = self.initializers.onnx_initializer_62
        block10_conv2_se_add1 = getattr(self, "block10/conv2/se/add1")(block10_conv2_se_matmul1, initializers_onnx_initializer_62);  block10_conv2_se_matmul1 = initializers_onnx_initializer_62 = None
        block10_conv2_se_relu = getattr(self, "block10/conv2/se/relu")(block10_conv2_se_add1);  block10_conv2_se_add1 = None
        initializers_onnx_initializer_63 = self.initializers.onnx_initializer_63
        block10_conv2_se_matmul2 = getattr(self, "block10/conv2/se/matmul2")(block10_conv2_se_relu, initializers_onnx_initializer_63);  block10_conv2_se_relu = initializers_onnx_initializer_63 = None
        initializers_onnx_initializer_64 = self.initializers.onnx_initializer_64
        block10_conv2_se_add2 = getattr(self, "block10/conv2/se/add2")(block10_conv2_se_matmul2, initializers_onnx_initializer_64);  block10_conv2_se_matmul2 = initializers_onnx_initializer_64 = None
        initializers_onnx_initializer_65 = self.initializers.onnx_initializer_65
        block10_conv2_se_reshape = getattr(self, "block10/conv2/se/reshape")(block10_conv2_se_add2, initializers_onnx_initializer_65);  block10_conv2_se_add2 = initializers_onnx_initializer_65 = None
        block10_conv2_se_split = getattr(self, "block10/conv2/se/split")(block10_conv2_se_reshape);  block10_conv2_se_reshape = None
        getitem_20 = block10_conv2_se_split[0]
        block10_conv2_se_sigmoid = getattr(self, "block10/conv2/se/sigmoid")(getitem_20);  getitem_20 = None
        block10_conv2_se_mul = getattr(self, "block10/conv2/se/mul")(block10_conv2_se_sigmoid, block10_conv2);  block10_conv2_se_sigmoid = block10_conv2 = None
        getitem_21 = block10_conv2_se_split[1];  block10_conv2_se_split = None
        block10_conv2_se_add3 = getattr(self, "block10/conv2/se/add3")(block10_conv2_se_mul, getitem_21);  block10_conv2_se_mul = getitem_21 = None
        block10_conv2_mixin = getattr(self, "block10/conv2/mixin")(block10_conv2_se_add3, block9_conv2_relu);  block10_conv2_se_add3 = block9_conv2_relu = None
        block10_conv2_relu = getattr(self, "block10/conv2/relu")(block10_conv2_mixin);  block10_conv2_mixin = None
        block11_conv1 = getattr(self, "block11/conv1")(block10_conv2_relu)
        block11_conv1_relu = getattr(self, "block11/conv1/relu")(block11_conv1);  block11_conv1 = None
        block11_conv2 = getattr(self, "block11/conv2")(block11_conv1_relu);  block11_conv1_relu = None
        block11_conv2_se_pooled = getattr(self, "block11/conv2/se/pooled")(block11_conv2)
        initializers_onnx_initializer_66 = self.initializers.onnx_initializer_66
        block11_conv2_se_squeeze = getattr(self, "block11/conv2/se/squeeze")(block11_conv2_se_pooled, initializers_onnx_initializer_66);  block11_conv2_se_pooled = initializers_onnx_initializer_66 = None
        initializers_onnx_initializer_67 = self.initializers.onnx_initializer_67
        block11_conv2_se_matmul1 = getattr(self, "block11/conv2/se/matmul1")(block11_conv2_se_squeeze, initializers_onnx_initializer_67);  block11_conv2_se_squeeze = initializers_onnx_initializer_67 = None
        initializers_onnx_initializer_68 = self.initializers.onnx_initializer_68
        block11_conv2_se_add1 = getattr(self, "block11/conv2/se/add1")(block11_conv2_se_matmul1, initializers_onnx_initializer_68);  block11_conv2_se_matmul1 = initializers_onnx_initializer_68 = None
        block11_conv2_se_relu = getattr(self, "block11/conv2/se/relu")(block11_conv2_se_add1);  block11_conv2_se_add1 = None
        initializers_onnx_initializer_69 = self.initializers.onnx_initializer_69
        block11_conv2_se_matmul2 = getattr(self, "block11/conv2/se/matmul2")(block11_conv2_se_relu, initializers_onnx_initializer_69);  block11_conv2_se_relu = initializers_onnx_initializer_69 = None
        initializers_onnx_initializer_70 = self.initializers.onnx_initializer_70
        block11_conv2_se_add2 = getattr(self, "block11/conv2/se/add2")(block11_conv2_se_matmul2, initializers_onnx_initializer_70);  block11_conv2_se_matmul2 = initializers_onnx_initializer_70 = None
        initializers_onnx_initializer_71 = self.initializers.onnx_initializer_71
        block11_conv2_se_reshape = getattr(self, "block11/conv2/se/reshape")(block11_conv2_se_add2, initializers_onnx_initializer_71);  block11_conv2_se_add2 = initializers_onnx_initializer_71 = None
        block11_conv2_se_split = getattr(self, "block11/conv2/se/split")(block11_conv2_se_reshape);  block11_conv2_se_reshape = None
        getitem_22 = block11_conv2_se_split[0]
        block11_conv2_se_sigmoid = getattr(self, "block11/conv2/se/sigmoid")(getitem_22);  getitem_22 = None
        block11_conv2_se_mul = getattr(self, "block11/conv2/se/mul")(block11_conv2_se_sigmoid, block11_conv2);  block11_conv2_se_sigmoid = block11_conv2 = None
        getitem_23 = block11_conv2_se_split[1];  block11_conv2_se_split = None
        block11_conv2_se_add3 = getattr(self, "block11/conv2/se/add3")(block11_conv2_se_mul, getitem_23);  block11_conv2_se_mul = getitem_23 = None
        block11_conv2_mixin = getattr(self, "block11/conv2/mixin")(block11_conv2_se_add3, block10_conv2_relu);  block11_conv2_se_add3 = block10_conv2_relu = None
        block11_conv2_relu = getattr(self, "block11/conv2/relu")(block11_conv2_mixin);  block11_conv2_mixin = None
        block12_conv1 = getattr(self, "block12/conv1")(block11_conv2_relu)
        block12_conv1_relu = getattr(self, "block12/conv1/relu")(block12_conv1);  block12_conv1 = None
        block12_conv2 = getattr(self, "block12/conv2")(block12_conv1_relu);  block12_conv1_relu = None
        block12_conv2_se_pooled = getattr(self, "block12/conv2/se/pooled")(block12_conv2)
        initializers_onnx_initializer_72 = self.initializers.onnx_initializer_72
        block12_conv2_se_squeeze = getattr(self, "block12/conv2/se/squeeze")(block12_conv2_se_pooled, initializers_onnx_initializer_72);  block12_conv2_se_pooled = initializers_onnx_initializer_72 = None
        initializers_onnx_initializer_73 = self.initializers.onnx_initializer_73
        block12_conv2_se_matmul1 = getattr(self, "block12/conv2/se/matmul1")(block12_conv2_se_squeeze, initializers_onnx_initializer_73);  block12_conv2_se_squeeze = initializers_onnx_initializer_73 = None
        initializers_onnx_initializer_74 = self.initializers.onnx_initializer_74
        block12_conv2_se_add1 = getattr(self, "block12/conv2/se/add1")(block12_conv2_se_matmul1, initializers_onnx_initializer_74);  block12_conv2_se_matmul1 = initializers_onnx_initializer_74 = None
        block12_conv2_se_relu = getattr(self, "block12/conv2/se/relu")(block12_conv2_se_add1);  block12_conv2_se_add1 = None
        initializers_onnx_initializer_75 = self.initializers.onnx_initializer_75
        block12_conv2_se_matmul2 = getattr(self, "block12/conv2/se/matmul2")(block12_conv2_se_relu, initializers_onnx_initializer_75);  block12_conv2_se_relu = initializers_onnx_initializer_75 = None
        initializers_onnx_initializer_76 = self.initializers.onnx_initializer_76
        block12_conv2_se_add2 = getattr(self, "block12/conv2/se/add2")(block12_conv2_se_matmul2, initializers_onnx_initializer_76);  block12_conv2_se_matmul2 = initializers_onnx_initializer_76 = None
        initializers_onnx_initializer_77 = self.initializers.onnx_initializer_77
        block12_conv2_se_reshape = getattr(self, "block12/conv2/se/reshape")(block12_conv2_se_add2, initializers_onnx_initializer_77);  block12_conv2_se_add2 = initializers_onnx_initializer_77 = None
        block12_conv2_se_split = getattr(self, "block12/conv2/se/split")(block12_conv2_se_reshape);  block12_conv2_se_reshape = None
        getitem_24 = block12_conv2_se_split[0]
        block12_conv2_se_sigmoid = getattr(self, "block12/conv2/se/sigmoid")(getitem_24);  getitem_24 = None
        block12_conv2_se_mul = getattr(self, "block12/conv2/se/mul")(block12_conv2_se_sigmoid, block12_conv2);  block12_conv2_se_sigmoid = block12_conv2 = None
        getitem_25 = block12_conv2_se_split[1];  block12_conv2_se_split = None
        block12_conv2_se_add3 = getattr(self, "block12/conv2/se/add3")(block12_conv2_se_mul, getitem_25);  block12_conv2_se_mul = getitem_25 = None
        block12_conv2_mixin = getattr(self, "block12/conv2/mixin")(block12_conv2_se_add3, block11_conv2_relu);  block12_conv2_se_add3 = block11_conv2_relu = None
        block12_conv2_relu = getattr(self, "block12/conv2/relu")(block12_conv2_mixin);  block12_conv2_mixin = None
        block13_conv1 = getattr(self, "block13/conv1")(block12_conv2_relu)
        block13_conv1_relu = getattr(self, "block13/conv1/relu")(block13_conv1);  block13_conv1 = None
        block13_conv2 = getattr(self, "block13/conv2")(block13_conv1_relu);  block13_conv1_relu = None
        block13_conv2_se_pooled = getattr(self, "block13/conv2/se/pooled")(block13_conv2)
        initializers_onnx_initializer_78 = self.initializers.onnx_initializer_78
        block13_conv2_se_squeeze = getattr(self, "block13/conv2/se/squeeze")(block13_conv2_se_pooled, initializers_onnx_initializer_78);  block13_conv2_se_pooled = initializers_onnx_initializer_78 = None
        initializers_onnx_initializer_79 = self.initializers.onnx_initializer_79
        block13_conv2_se_matmul1 = getattr(self, "block13/conv2/se/matmul1")(block13_conv2_se_squeeze, initializers_onnx_initializer_79);  block13_conv2_se_squeeze = initializers_onnx_initializer_79 = None
        initializers_onnx_initializer_80 = self.initializers.onnx_initializer_80
        block13_conv2_se_add1 = getattr(self, "block13/conv2/se/add1")(block13_conv2_se_matmul1, initializers_onnx_initializer_80);  block13_conv2_se_matmul1 = initializers_onnx_initializer_80 = None
        block13_conv2_se_relu = getattr(self, "block13/conv2/se/relu")(block13_conv2_se_add1);  block13_conv2_se_add1 = None
        initializers_onnx_initializer_81 = self.initializers.onnx_initializer_81
        block13_conv2_se_matmul2 = getattr(self, "block13/conv2/se/matmul2")(block13_conv2_se_relu, initializers_onnx_initializer_81);  block13_conv2_se_relu = initializers_onnx_initializer_81 = None
        initializers_onnx_initializer_82 = self.initializers.onnx_initializer_82
        block13_conv2_se_add2 = getattr(self, "block13/conv2/se/add2")(block13_conv2_se_matmul2, initializers_onnx_initializer_82);  block13_conv2_se_matmul2 = initializers_onnx_initializer_82 = None
        initializers_onnx_initializer_83 = self.initializers.onnx_initializer_83
        block13_conv2_se_reshape = getattr(self, "block13/conv2/se/reshape")(block13_conv2_se_add2, initializers_onnx_initializer_83);  block13_conv2_se_add2 = initializers_onnx_initializer_83 = None
        block13_conv2_se_split = getattr(self, "block13/conv2/se/split")(block13_conv2_se_reshape);  block13_conv2_se_reshape = None
        getitem_26 = block13_conv2_se_split[0]
        block13_conv2_se_sigmoid = getattr(self, "block13/conv2/se/sigmoid")(getitem_26);  getitem_26 = None
        block13_conv2_se_mul = getattr(self, "block13/conv2/se/mul")(block13_conv2_se_sigmoid, block13_conv2);  block13_conv2_se_sigmoid = block13_conv2 = None
        getitem_27 = block13_conv2_se_split[1];  block13_conv2_se_split = None
        block13_conv2_se_add3 = getattr(self, "block13/conv2/se/add3")(block13_conv2_se_mul, getitem_27);  block13_conv2_se_mul = getitem_27 = None
        block13_conv2_mixin = getattr(self, "block13/conv2/mixin")(block13_conv2_se_add3, block12_conv2_relu);  block13_conv2_se_add3 = block12_conv2_relu = None
        block13_conv2_relu = getattr(self, "block13/conv2/relu")(block13_conv2_mixin);  block13_conv2_mixin = None
        block14_conv1 = getattr(self, "block14/conv1")(block13_conv2_relu)
        block14_conv1_relu = getattr(self, "block14/conv1/relu")(block14_conv1);  block14_conv1 = None
        block14_conv2 = getattr(self, "block14/conv2")(block14_conv1_relu);  block14_conv1_relu = None
        block14_conv2_se_pooled = getattr(self, "block14/conv2/se/pooled")(block14_conv2)
        initializers_onnx_initializer_84 = self.initializers.onnx_initializer_84
        block14_conv2_se_squeeze = getattr(self, "block14/conv2/se/squeeze")(block14_conv2_se_pooled, initializers_onnx_initializer_84);  block14_conv2_se_pooled = initializers_onnx_initializer_84 = None
        initializers_onnx_initializer_85 = self.initializers.onnx_initializer_85
        block14_conv2_se_matmul1 = getattr(self, "block14/conv2/se/matmul1")(block14_conv2_se_squeeze, initializers_onnx_initializer_85);  block14_conv2_se_squeeze = initializers_onnx_initializer_85 = None
        initializers_onnx_initializer_86 = self.initializers.onnx_initializer_86
        block14_conv2_se_add1 = getattr(self, "block14/conv2/se/add1")(block14_conv2_se_matmul1, initializers_onnx_initializer_86);  block14_conv2_se_matmul1 = initializers_onnx_initializer_86 = None
        block14_conv2_se_relu = getattr(self, "block14/conv2/se/relu")(block14_conv2_se_add1);  block14_conv2_se_add1 = None
        initializers_onnx_initializer_87 = self.initializers.onnx_initializer_87
        block14_conv2_se_matmul2 = getattr(self, "block14/conv2/se/matmul2")(block14_conv2_se_relu, initializers_onnx_initializer_87);  block14_conv2_se_relu = initializers_onnx_initializer_87 = None
        initializers_onnx_initializer_88 = self.initializers.onnx_initializer_88
        block14_conv2_se_add2 = getattr(self, "block14/conv2/se/add2")(block14_conv2_se_matmul2, initializers_onnx_initializer_88);  block14_conv2_se_matmul2 = initializers_onnx_initializer_88 = None
        initializers_onnx_initializer_89 = self.initializers.onnx_initializer_89
        block14_conv2_se_reshape = getattr(self, "block14/conv2/se/reshape")(block14_conv2_se_add2, initializers_onnx_initializer_89);  block14_conv2_se_add2 = initializers_onnx_initializer_89 = None
        block14_conv2_se_split = getattr(self, "block14/conv2/se/split")(block14_conv2_se_reshape);  block14_conv2_se_reshape = None
        getitem_28 = block14_conv2_se_split[0]
        block14_conv2_se_sigmoid = getattr(self, "block14/conv2/se/sigmoid")(getitem_28);  getitem_28 = None
        block14_conv2_se_mul = getattr(self, "block14/conv2/se/mul")(block14_conv2_se_sigmoid, block14_conv2);  block14_conv2_se_sigmoid = block14_conv2 = None
        getitem_29 = block14_conv2_se_split[1];  block14_conv2_se_split = None
        block14_conv2_se_add3 = getattr(self, "block14/conv2/se/add3")(block14_conv2_se_mul, getitem_29);  block14_conv2_se_mul = getitem_29 = None
        block14_conv2_mixin = getattr(self, "block14/conv2/mixin")(block14_conv2_se_add3, block13_conv2_relu);  block14_conv2_se_add3 = block13_conv2_relu = None
        block14_conv2_relu = getattr(self, "block14/conv2/relu")(block14_conv2_mixin);  block14_conv2_mixin = None
        block15_conv1 = getattr(self, "block15/conv1")(block14_conv2_relu)
        block15_conv1_relu = getattr(self, "block15/conv1/relu")(block15_conv1);  block15_conv1 = None
        block15_conv2 = getattr(self, "block15/conv2")(block15_conv1_relu);  block15_conv1_relu = None
        block15_conv2_se_pooled = getattr(self, "block15/conv2/se/pooled")(block15_conv2)
        initializers_onnx_initializer_90 = self.initializers.onnx_initializer_90
        block15_conv2_se_squeeze = getattr(self, "block15/conv2/se/squeeze")(block15_conv2_se_pooled, initializers_onnx_initializer_90);  block15_conv2_se_pooled = initializers_onnx_initializer_90 = None
        initializers_onnx_initializer_91 = self.initializers.onnx_initializer_91
        block15_conv2_se_matmul1 = getattr(self, "block15/conv2/se/matmul1")(block15_conv2_se_squeeze, initializers_onnx_initializer_91);  block15_conv2_se_squeeze = initializers_onnx_initializer_91 = None
        initializers_onnx_initializer_92 = self.initializers.onnx_initializer_92
        block15_conv2_se_add1 = getattr(self, "block15/conv2/se/add1")(block15_conv2_se_matmul1, initializers_onnx_initializer_92);  block15_conv2_se_matmul1 = initializers_onnx_initializer_92 = None
        block15_conv2_se_relu = getattr(self, "block15/conv2/se/relu")(block15_conv2_se_add1);  block15_conv2_se_add1 = None
        initializers_onnx_initializer_93 = self.initializers.onnx_initializer_93
        block15_conv2_se_matmul2 = getattr(self, "block15/conv2/se/matmul2")(block15_conv2_se_relu, initializers_onnx_initializer_93);  block15_conv2_se_relu = initializers_onnx_initializer_93 = None
        initializers_onnx_initializer_94 = self.initializers.onnx_initializer_94
        block15_conv2_se_add2 = getattr(self, "block15/conv2/se/add2")(block15_conv2_se_matmul2, initializers_onnx_initializer_94);  block15_conv2_se_matmul2 = initializers_onnx_initializer_94 = None
        initializers_onnx_initializer_95 = self.initializers.onnx_initializer_95
        block15_conv2_se_reshape = getattr(self, "block15/conv2/se/reshape")(block15_conv2_se_add2, initializers_onnx_initializer_95);  block15_conv2_se_add2 = initializers_onnx_initializer_95 = None
        block15_conv2_se_split = getattr(self, "block15/conv2/se/split")(block15_conv2_se_reshape);  block15_conv2_se_reshape = None
        getitem_30 = block15_conv2_se_split[0]
        block15_conv2_se_sigmoid = getattr(self, "block15/conv2/se/sigmoid")(getitem_30);  getitem_30 = None
        block15_conv2_se_mul = getattr(self, "block15/conv2/se/mul")(block15_conv2_se_sigmoid, block15_conv2);  block15_conv2_se_sigmoid = block15_conv2 = None
        getitem_31 = block15_conv2_se_split[1];  block15_conv2_se_split = None
        block15_conv2_se_add3 = getattr(self, "block15/conv2/se/add3")(block15_conv2_se_mul, getitem_31);  block15_conv2_se_mul = getitem_31 = None
        block15_conv2_mixin = getattr(self, "block15/conv2/mixin")(block15_conv2_se_add3, block14_conv2_relu);  block15_conv2_se_add3 = block14_conv2_relu = None
        block15_conv2_relu = getattr(self, "block15/conv2/relu")(block15_conv2_mixin);  block15_conv2_mixin = None
        block16_conv1 = getattr(self, "block16/conv1")(block15_conv2_relu)
        block16_conv1_relu = getattr(self, "block16/conv1/relu")(block16_conv1);  block16_conv1 = None
        block16_conv2 = getattr(self, "block16/conv2")(block16_conv1_relu);  block16_conv1_relu = None
        block16_conv2_se_pooled = getattr(self, "block16/conv2/se/pooled")(block16_conv2)
        initializers_onnx_initializer_96 = self.initializers.onnx_initializer_96
        block16_conv2_se_squeeze = getattr(self, "block16/conv2/se/squeeze")(block16_conv2_se_pooled, initializers_onnx_initializer_96);  block16_conv2_se_pooled = initializers_onnx_initializer_96 = None
        initializers_onnx_initializer_97 = self.initializers.onnx_initializer_97
        block16_conv2_se_matmul1 = getattr(self, "block16/conv2/se/matmul1")(block16_conv2_se_squeeze, initializers_onnx_initializer_97);  block16_conv2_se_squeeze = initializers_onnx_initializer_97 = None
        initializers_onnx_initializer_98 = self.initializers.onnx_initializer_98
        block16_conv2_se_add1 = getattr(self, "block16/conv2/se/add1")(block16_conv2_se_matmul1, initializers_onnx_initializer_98);  block16_conv2_se_matmul1 = initializers_onnx_initializer_98 = None
        block16_conv2_se_relu = getattr(self, "block16/conv2/se/relu")(block16_conv2_se_add1);  block16_conv2_se_add1 = None
        initializers_onnx_initializer_99 = self.initializers.onnx_initializer_99
        block16_conv2_se_matmul2 = getattr(self, "block16/conv2/se/matmul2")(block16_conv2_se_relu, initializers_onnx_initializer_99);  block16_conv2_se_relu = initializers_onnx_initializer_99 = None
        initializers_onnx_initializer_100 = self.initializers.onnx_initializer_100
        block16_conv2_se_add2 = getattr(self, "block16/conv2/se/add2")(block16_conv2_se_matmul2, initializers_onnx_initializer_100);  block16_conv2_se_matmul2 = initializers_onnx_initializer_100 = None
        initializers_onnx_initializer_101 = self.initializers.onnx_initializer_101
        block16_conv2_se_reshape = getattr(self, "block16/conv2/se/reshape")(block16_conv2_se_add2, initializers_onnx_initializer_101);  block16_conv2_se_add2 = initializers_onnx_initializer_101 = None
        block16_conv2_se_split = getattr(self, "block16/conv2/se/split")(block16_conv2_se_reshape);  block16_conv2_se_reshape = None
        getitem_32 = block16_conv2_se_split[0]
        block16_conv2_se_sigmoid = getattr(self, "block16/conv2/se/sigmoid")(getitem_32);  getitem_32 = None
        block16_conv2_se_mul = getattr(self, "block16/conv2/se/mul")(block16_conv2_se_sigmoid, block16_conv2);  block16_conv2_se_sigmoid = block16_conv2 = None
        getitem_33 = block16_conv2_se_split[1];  block16_conv2_se_split = None
        block16_conv2_se_add3 = getattr(self, "block16/conv2/se/add3")(block16_conv2_se_mul, getitem_33);  block16_conv2_se_mul = getitem_33 = None
        block16_conv2_mixin = getattr(self, "block16/conv2/mixin")(block16_conv2_se_add3, block15_conv2_relu);  block16_conv2_se_add3 = block15_conv2_relu = None
        block16_conv2_relu = getattr(self, "block16/conv2/relu")(block16_conv2_mixin);  block16_conv2_mixin = None
        block17_conv1 = getattr(self, "block17/conv1")(block16_conv2_relu)
        block17_conv1_relu = getattr(self, "block17/conv1/relu")(block17_conv1);  block17_conv1 = None
        block17_conv2 = getattr(self, "block17/conv2")(block17_conv1_relu);  block17_conv1_relu = None
        block17_conv2_se_pooled = getattr(self, "block17/conv2/se/pooled")(block17_conv2)
        initializers_onnx_initializer_102 = self.initializers.onnx_initializer_102
        block17_conv2_se_squeeze = getattr(self, "block17/conv2/se/squeeze")(block17_conv2_se_pooled, initializers_onnx_initializer_102);  block17_conv2_se_pooled = initializers_onnx_initializer_102 = None
        initializers_onnx_initializer_103 = self.initializers.onnx_initializer_103
        block17_conv2_se_matmul1 = getattr(self, "block17/conv2/se/matmul1")(block17_conv2_se_squeeze, initializers_onnx_initializer_103);  block17_conv2_se_squeeze = initializers_onnx_initializer_103 = None
        initializers_onnx_initializer_104 = self.initializers.onnx_initializer_104
        block17_conv2_se_add1 = getattr(self, "block17/conv2/se/add1")(block17_conv2_se_matmul1, initializers_onnx_initializer_104);  block17_conv2_se_matmul1 = initializers_onnx_initializer_104 = None
        block17_conv2_se_relu = getattr(self, "block17/conv2/se/relu")(block17_conv2_se_add1);  block17_conv2_se_add1 = None
        initializers_onnx_initializer_105 = self.initializers.onnx_initializer_105
        block17_conv2_se_matmul2 = getattr(self, "block17/conv2/se/matmul2")(block17_conv2_se_relu, initializers_onnx_initializer_105);  block17_conv2_se_relu = initializers_onnx_initializer_105 = None
        initializers_onnx_initializer_106 = self.initializers.onnx_initializer_106
        block17_conv2_se_add2 = getattr(self, "block17/conv2/se/add2")(block17_conv2_se_matmul2, initializers_onnx_initializer_106);  block17_conv2_se_matmul2 = initializers_onnx_initializer_106 = None
        initializers_onnx_initializer_107 = self.initializers.onnx_initializer_107
        block17_conv2_se_reshape = getattr(self, "block17/conv2/se/reshape")(block17_conv2_se_add2, initializers_onnx_initializer_107);  block17_conv2_se_add2 = initializers_onnx_initializer_107 = None
        block17_conv2_se_split = getattr(self, "block17/conv2/se/split")(block17_conv2_se_reshape);  block17_conv2_se_reshape = None
        getitem_34 = block17_conv2_se_split[0]
        block17_conv2_se_sigmoid = getattr(self, "block17/conv2/se/sigmoid")(getitem_34);  getitem_34 = None
        block17_conv2_se_mul = getattr(self, "block17/conv2/se/mul")(block17_conv2_se_sigmoid, block17_conv2);  block17_conv2_se_sigmoid = block17_conv2 = None
        getitem_35 = block17_conv2_se_split[1];  block17_conv2_se_split = None
        block17_conv2_se_add3 = getattr(self, "block17/conv2/se/add3")(block17_conv2_se_mul, getitem_35);  block17_conv2_se_mul = getitem_35 = None
        block17_conv2_mixin = getattr(self, "block17/conv2/mixin")(block17_conv2_se_add3, block16_conv2_relu);  block17_conv2_se_add3 = block16_conv2_relu = None
        block17_conv2_relu = getattr(self, "block17/conv2/relu")(block17_conv2_mixin);  block17_conv2_mixin = None
        block18_conv1 = getattr(self, "block18/conv1")(block17_conv2_relu)
        block18_conv1_relu = getattr(self, "block18/conv1/relu")(block18_conv1);  block18_conv1 = None
        block18_conv2 = getattr(self, "block18/conv2")(block18_conv1_relu);  block18_conv1_relu = None
        block18_conv2_se_pooled = getattr(self, "block18/conv2/se/pooled")(block18_conv2)
        initializers_onnx_initializer_108 = self.initializers.onnx_initializer_108
        block18_conv2_se_squeeze = getattr(self, "block18/conv2/se/squeeze")(block18_conv2_se_pooled, initializers_onnx_initializer_108);  block18_conv2_se_pooled = initializers_onnx_initializer_108 = None
        initializers_onnx_initializer_109 = self.initializers.onnx_initializer_109
        block18_conv2_se_matmul1 = getattr(self, "block18/conv2/se/matmul1")(block18_conv2_se_squeeze, initializers_onnx_initializer_109);  block18_conv2_se_squeeze = initializers_onnx_initializer_109 = None
        initializers_onnx_initializer_110 = self.initializers.onnx_initializer_110
        block18_conv2_se_add1 = getattr(self, "block18/conv2/se/add1")(block18_conv2_se_matmul1, initializers_onnx_initializer_110);  block18_conv2_se_matmul1 = initializers_onnx_initializer_110 = None
        block18_conv2_se_relu = getattr(self, "block18/conv2/se/relu")(block18_conv2_se_add1);  block18_conv2_se_add1 = None
        initializers_onnx_initializer_111 = self.initializers.onnx_initializer_111
        block18_conv2_se_matmul2 = getattr(self, "block18/conv2/se/matmul2")(block18_conv2_se_relu, initializers_onnx_initializer_111);  block18_conv2_se_relu = initializers_onnx_initializer_111 = None
        initializers_onnx_initializer_112 = self.initializers.onnx_initializer_112
        block18_conv2_se_add2 = getattr(self, "block18/conv2/se/add2")(block18_conv2_se_matmul2, initializers_onnx_initializer_112);  block18_conv2_se_matmul2 = initializers_onnx_initializer_112 = None
        initializers_onnx_initializer_113 = self.initializers.onnx_initializer_113
        block18_conv2_se_reshape = getattr(self, "block18/conv2/se/reshape")(block18_conv2_se_add2, initializers_onnx_initializer_113);  block18_conv2_se_add2 = initializers_onnx_initializer_113 = None
        block18_conv2_se_split = getattr(self, "block18/conv2/se/split")(block18_conv2_se_reshape);  block18_conv2_se_reshape = None
        getitem_36 = block18_conv2_se_split[0]
        block18_conv2_se_sigmoid = getattr(self, "block18/conv2/se/sigmoid")(getitem_36);  getitem_36 = None
        block18_conv2_se_mul = getattr(self, "block18/conv2/se/mul")(block18_conv2_se_sigmoid, block18_conv2);  block18_conv2_se_sigmoid = block18_conv2 = None
        getitem_37 = block18_conv2_se_split[1];  block18_conv2_se_split = None
        block18_conv2_se_add3 = getattr(self, "block18/conv2/se/add3")(block18_conv2_se_mul, getitem_37);  block18_conv2_se_mul = getitem_37 = None
        block18_conv2_mixin = getattr(self, "block18/conv2/mixin")(block18_conv2_se_add3, block17_conv2_relu);  block18_conv2_se_add3 = block17_conv2_relu = None
        block18_conv2_relu = getattr(self, "block18/conv2/relu")(block18_conv2_mixin);  block18_conv2_mixin = None
        policy_conv1 = getattr(self, "policy/conv1")(block18_conv2_relu)
        policy_conv1_relu = getattr(self, "policy/conv1/relu")(policy_conv1);  policy_conv1 = None
        policy_conv2 = getattr(self, "policy/conv2")(policy_conv1_relu);  policy_conv1_relu = None
        initializers_onnx_initializer_114 = self.initializers.onnx_initializer_114
        policy_flatten = getattr(self, "policy/flatten")(policy_conv2, initializers_onnx_initializer_114);  policy_conv2 = initializers_onnx_initializer_114 = None
        initializers_onnx_initializer_115 = self.initializers.onnx_initializer_115
        output_policy = getattr(self, "output/policy")(policy_flatten, initializers_onnx_initializer_115);  policy_flatten = initializers_onnx_initializer_115 = None
        value_conv = getattr(self, "value/conv")(block18_conv2_relu)
        value_conv_relu = getattr(self, "value/conv/relu")(value_conv);  value_conv = None
        initializers_onnx_initializer_116 = self.initializers.onnx_initializer_116
        value_reshape = getattr(self, "value/reshape")(value_conv_relu, initializers_onnx_initializer_116);  value_conv_relu = initializers_onnx_initializer_116 = None
        initializers_onnx_initializer_117 = self.initializers.onnx_initializer_117
        value_dense1_matmul = getattr(self, "value/dense1/matmul")(value_reshape, initializers_onnx_initializer_117);  value_reshape = initializers_onnx_initializer_117 = None
        initializers_onnx_initializer_118 = self.initializers.onnx_initializer_118
        value_dense1_add = getattr(self, "value/dense1/add")(value_dense1_matmul, initializers_onnx_initializer_118);  value_dense1_matmul = initializers_onnx_initializer_118 = None
        value_dense1_relu = getattr(self, "value/dense1/relu")(value_dense1_add);  value_dense1_add = None
        initializers_onnx_initializer_119 = self.initializers.onnx_initializer_119
        value_dense2_matmul = getattr(self, "value/dense2/matmul")(value_dense1_relu, initializers_onnx_initializer_119);  value_dense1_relu = initializers_onnx_initializer_119 = None
        initializers_onnx_initializer_120 = self.initializers.onnx_initializer_120
        value_dense2_add = getattr(self, "value/dense2/add")(value_dense2_matmul, initializers_onnx_initializer_120);  value_dense2_matmul = initializers_onnx_initializer_120 = None
        output_wdl = getattr(self, "output/wdl")(value_dense2_add);  value_dense2_add = None
        mlh_conv = getattr(self, "mlh/conv")(block18_conv2_relu);  block18_conv2_relu = None
        mlh_conv_relu = getattr(self, "mlh/conv/relu")(mlh_conv);  mlh_conv = None
        initializers_onnx_initializer_121 = self.initializers.onnx_initializer_121
        mlh_reshape = getattr(self, "mlh/reshape")(mlh_conv_relu, initializers_onnx_initializer_121);  mlh_conv_relu = initializers_onnx_initializer_121 = None
        initializers_onnx_initializer_122 = self.initializers.onnx_initializer_122
        mlh_dense1_matmul = getattr(self, "mlh/dense1/matmul")(mlh_reshape, initializers_onnx_initializer_122);  mlh_reshape = initializers_onnx_initializer_122 = None
        initializers_onnx_initializer_123 = self.initializers.onnx_initializer_123
        mlh_dense1_add = getattr(self, "mlh/dense1/add")(mlh_dense1_matmul, initializers_onnx_initializer_123);  mlh_dense1_matmul = initializers_onnx_initializer_123 = None
        mlh_dense1_relu = getattr(self, "mlh/dense1/relu")(mlh_dense1_add);  mlh_dense1_add = None
        initializers_onnx_initializer_124 = self.initializers.onnx_initializer_124
        mlh_dense2_matmul = getattr(self, "mlh/dense2/matmul")(mlh_dense1_relu, initializers_onnx_initializer_124);  mlh_dense1_relu = initializers_onnx_initializer_124 = None
        initializers_onnx_initializer_125 = self.initializers.onnx_initializer_125
        mlh_dense2_add = getattr(self, "mlh/dense2/add")(mlh_dense2_matmul, initializers_onnx_initializer_125);  mlh_dense2_matmul = initializers_onnx_initializer_125 = None
        mlh_dense2_relu = getattr(self, "mlh/dense2/relu")(mlh_dense2_add);  mlh_dense2_add = None
        output_mlh = getattr(self, "output/mlh")(mlh_dense2_relu);  mlh_dense2_relu = None
        return [output_policy, output_wdl, output_mlh]

    # To see more debug info, please use `graph_module.print_readable()`,
    device=cpu,
    in_keys=['board'],
    out_keys=['policy', 'wdl', 'mlh'])
[4]:
from tensordict import TensorDict
from lczerolens import LczeroBoard

board = LczeroBoard(fen="1rb1rbk1/2qn1p1p/p2p2p1/1ppPp2n/PP2P3/2P1BN1P/R1BN1PP1/3QR1K1 w - - 0 22")
td = TensorDict({"board": board.to_input_tensor().unsqueeze(0)}, batch_size=[1])
td
[4]:
TensorDict(
    fields={
        board: Tensor(shape=torch.Size([1, 112, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)
[5]:
policy = model(board)["policy"]

legal_indices = board.get_legal_indices()
legal_policy = policy[0].gather(0, legal_indices)
best_move_index = legal_indices[legal_policy.argmax()]
move = board.decode_move(best_move_index)

board.push(move)
display(board)
../../_images/notebooks_tutorials_framework-agnostic-interpretability_7_0.svg

Causal Mediation Analysis with nnsight#

[6]:
from IPython.display import display
from nnsight import NNsight

from lczerolens.data import PuzzleData

MEDIATED_MODULE = "block4/conv1/relu"
TARGET_MODULE = "output/wdl"

puzzle = PuzzleData.from_dict(
    {
        "PuzzleId": "004JD",
        "FEN": "3r4/R7/2p5/p1P2p2/1p4k1/nP6/P2KNP2/8 w - - 3 41",
        "Moves": "d2e3 a3c2",
        "Rating": 1336,
        "RatingDeviation": 76,
        "Popularity": 91,
        "NbPlays": 2543,
        "Themes": "endgame mate mateIn1 oneMove",
        "GameUrl": "https://lichess.org/6vTeEc3x#80",
        "OpeningTags": None,
    }
)

base_board = puzzle.initial_board  # board after the forced first move
target_move = puzzle.moves[0]  # puzzle second move
first_legal_move = next(iter(base_board.legal_moves))

target_board = base_board.copy(stack=True)
target_board.push(target_move)

contrast_board = base_board.copy(stack=True)
contrast_board.push(first_legal_move)

print(f"puzzle id: {puzzle.puzzle_id}")
print(f"forced first move: {puzzle.initial_move.uci()}")
print(f"target second move: {target_move.uci()}")
print(f"contrast move (first legal): {first_legal_move.uci()}")
print("\nTarget board (base + target move):")
display(target_board)
print("Contrast board (base + first legal move):")
display(contrast_board)

nnsight_model = NNsight(model)

with nnsight_model.trace(target_board):
    source_activation = getattr(nnsight_model, MEDIATED_MODULE).output.clone().save()
    target_wdl = getattr(nnsight_model, TARGET_MODULE).output.clone().save()

with nnsight_model.trace(contrast_board):
    contrast_wdl = getattr(nnsight_model, TARGET_MODULE).output.clone().save()

with nnsight_model.trace(contrast_board):
    getattr(nnsight_model, MEDIATED_MODULE).output = source_activation
    mediated_wdl = getattr(nnsight_model, TARGET_MODULE).output.clone().save()

total_effect = target_wdl[0, 0] - contrast_wdl[0, 0]
indirect_effect = mediated_wdl[0, 0] - contrast_wdl[0, 0]
direct_effect = target_wdl[0, 0] - mediated_wdl[0, 0]
proportion_mediated = indirect_effect / total_effect if abs(total_effect) > 1e-9 else float("nan")

print("\nEffect measures on value (W - L):")
print(
    {
        "target_w": target_wdl[0, 0],
        "contrast_w": contrast_wdl[0, 0],
        "mediated_w": mediated_wdl[0, 0],
        "total_effect": total_effect,
        "indirect_effect": indirect_effect,
        "direct_effect": direct_effect,
        "proportion_mediated": proportion_mediated,
    }
)
puzzle id: 004JD
forced first move: d2e3
target second move: a3c2
contrast move (first legal): d8h8

Target board (base + target move):
../../_images/notebooks_tutorials_framework-agnostic-interpretability_9_1.svg
Contrast board (base + first legal move):
../../_images/notebooks_tutorials_framework-agnostic-interpretability_9_3.svg

Effect measures on value (W - L):
{'target_w': tensor(0.0740, grad_fn=<SelectBackward0>), 'contrast_w': tensor(0.7348, grad_fn=<SelectBackward0>), 'mediated_w': tensor(0.6883, grad_fn=<SelectBackward0>), 'total_effect': tensor(-0.6608, grad_fn=<SubBackward0>), 'indirect_effect': tensor(-0.0466, grad_fn=<SubBackward0>), 'direct_effect': tensor(-0.6143, grad_fn=<SubBackward0>), 'proportion_mediated': tensor(0.0704, grad_fn=<DivBackward0>)}

Integral-Gradient with captum#

[7]:
import torch


def mlh_forward(board: LczeroBoard) -> torch.Tensor:
    return model(board)["mlh"]


baseline = torch.zeros_like(td["board"])
print("wdl (board):", mlh_forward(board))
print("wdl (baseline):", mlh_forward(baseline))
wdl (board): tensor([[105.5060]], grad_fn=<CloneBackward0>)
wdl (baseline): tensor([[0.]], grad_fn=<CloneBackward0>)
[8]:
from captum.attr import IntegratedGradients

ig = IntegratedGradients(mlh_forward)
attributions, approximation_error = ig.attribute(td["board"], baselines=baseline, return_convergence_delta=True)

approximation_error
[8]:
tensor([-3.6674])
[9]:
import IPython.display

heatmap = attributions.mean(dim=1)

svg_board, svg_colorbar = board.render_heatmap(heatmap[0].view(64).detach(), normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. N P . B N . P
R . B . . P P .
. . . Q R . K .
2026-04-29T10:40:33.777938 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/

LRP with zennit#

[10]:
import torch
from torch.autograd import Function


def stabilize(tensor, epsilon=1e-6):
    return tensor + epsilon * ((-1) ** (tensor < 0))


class AddEpsilonFunction(Function):
    @staticmethod
    def forward(ctx, input_a, input_b, epsilon=1e-6):
        output = input_a + input_b
        ctx.save_for_backward(input_a, input_b, output, torch.tensor(epsilon))
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        input_a, input_b, output, epsilon = ctx.saved_tensors
        out_relevance = grad_output[0] / stabilize(output, epsilon)
        return out_relevance * input_a, out_relevance * input_b, None


class AddEpsilon(torch.nn.Module):
    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x, y):
        return AddEpsilonFunction.apply(x, y, self.epsilon)


class MatMulEpsilonFunction(Function):
    @staticmethod
    def forward(ctx, input, param, epsilon=1e-6):
        output = torch.matmul(input, param)
        ctx.save_for_backward(input, param, output, torch.tensor(epsilon))

        return output

    @staticmethod
    def backward(ctx, *grad_outputs):
        input, param, output, epsilon = ctx.saved_tensors
        out_relevance = grad_outputs[0]

        out_relevance = out_relevance / stabilize(output, epsilon)
        relevance = (out_relevance @ param.T) * input
        return relevance, None, None


class MatMulEpsilon(torch.nn.Module):
    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x, y):
        return MatMulEpsilonFunction.apply(x, y, self.epsilon)


class MulUniformFunction(Function):
    @staticmethod
    def forward(ctx, input_a, input_b):
        return input_a * input_b

    @staticmethod
    def backward(ctx, *grad_outputs):
        relevance = grad_outputs[0] * 0.5

        return relevance, relevance


class MulUniform(torch.nn.Module):
    def forward(self, x, y):
        return MulUniformFunction.apply(x, y)
[11]:
from zennit.canonizers import SequentialMergeBatchNorm
from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, Pass, ZPlus
from zennit.types import Activation
import onnx2torch

canonizers = [SequentialMergeBatchNorm()]
layer_map = [
    (Activation, Pass()),
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.Linear, Epsilon(epsilon=1e-6)),
    (torch.nn.AdaptiveAvgPool2d, Epsilon(epsilon=1e-6)),
]
composite = LayerMapComposite(layer_map=layer_map, canonizers=canonizers)

new_module_mapping = {}
old_module_mapping = {}
for name, module in model.named_modules():
    if name == "":
        continue
    if isinstance(module, torch.nn.Softmax):
        new_module_mapping[name] = torch.nn.Identity()
        old_module_mapping[name] = module
    if isinstance(module, onnx2torch.node_converters.OnnxBinaryMathOperation):
        if module.math_op_function is torch.add:
            new_module_mapping[name] = AddEpsilon()
            old_module_mapping[name] = module
        elif module.math_op_function is torch.mul:
            new_module_mapping[name] = MulUniform()
            old_module_mapping[name] = module
    elif isinstance(module, onnx2torch.node_converters.OnnxMatMul):
        new_module_mapping[name] = MatMulEpsilon()
        old_module_mapping[name] = module
    elif isinstance(module, onnx2torch.node_converters.OnnxFunction):
        if module.function is torch.tanh:
            new_module_mapping[name] = torch.nn.Tanh()
            old_module_mapping[name] = module
    elif isinstance(
        module,
        onnx2torch.node_converters.OnnxGlobalAveragePoolWithKnownInputShape,  # noqa
    ):
        new_module_mapping[name] = torch.nn.AdaptiveAvgPool2d(1)
        old_module_mapping[name] = module
for name, module in new_module_mapping.items():
    setattr(model, name, module)
[12]:
with composite.context(model) as modified_model:
    td["board"].requires_grad_(True)
    output = modified_model(td["board"])["policy"]

    (relevance,) = torch.autograd.grad(output[:, best_move_index], td["board"])
[13]:
import IPython.display

heatmap = relevance.mean(dim=1)

svg_board, svg_colorbar = board.render_heatmap(heatmap[0].view(64).detach(), normalise="abs")
display(IPython.display.HTML(f"{svg_board}{svg_colorbar}"))
. r b . r b k .
. . q n . p . p
p . . p . . p .
. p p P p . . n
P P . . P . . .
. N P . B N . P
R . B . . P P .
. . . Q R . K .
2026-04-29T10:40:34.029594 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/
[14]:
for name, module in old_module_mapping.items():
    setattr(model, name, module)

Per-Square Dimension Estimation with tdhook#

[15]:
import torch
from datasets import Dataset, load_dataset
from IPython.display import HTML, display
from tdhook.latent import ActivationCaching
from tdhook.latent.dimension_estimation import TwoNnDimensionEstimator

TARGET_LAYER = "block0/conv2/relu"
TARGET_PATTERN = rf".*{TARGET_LAYER}.*"
N_SAMPLES = 128
SEED = 42


def to_eager_subset(stream_ds, n: int) -> Dataset:
    stream_ds = stream_ds.shuffle(seed=SEED).take(n)
    return Dataset.from_generator(lambda: (yield from stream_ds), features=stream_ds.features)


def row_to_board(row: dict) -> LczeroBoard | None:
    fen = row.get("fen") if isinstance(row, dict) else None
    if not fen:
        return None
    try:
        return LczeroBoard(fen)
    except Exception:
        return None


raw_stream = load_dataset("lczerolens/tcec-boards", split="train", streaming=True)
subset = to_eager_subset(raw_stream, N_SAMPLES)
rows = list(subset)
boards = [b for b in (row_to_board(r) for r in rows) if b is not None]
if len(boards) < 8:
    raise RuntimeError("Not enough valid boards for estimation; increase N_SAMPLES.")

board_tensor = torch.stack([b.to_input_tensor() for b in boards], dim=0)
td_batch = TensorDict({"board": board_tensor}, batch_size=[len(boards)])

cache_context = ActivationCaching(TARGET_PATTERN, relative=True, clear_cache=True)
with cache_context.prepare(model) as hooked_model:
    _ = hooked_model(td_batch)
    cache = hooked_model.hooking_context.cache

activation = next(
    value.detach().cpu()
    for key, value in cache.items(True, True)
    if str(key).endswith(TARGET_LAYER) and torch.is_tensor(value) and value.ndim == 4
)

estimator = TwoNnDimensionEstimator()
n, c, h, w = activation.shape
square_samples = activation.permute(2, 3, 0, 1).reshape(h * w, n, c)

square_dims = []
for i in range(h * w):
    td_i = TensorDict({"data": square_samples[i]}, batch_size=[])
    try:
        d = estimator(td_i)["dimension"].item()
    except ValueError:
        d = float("nan")
    square_dims.append(d)

square_dims = torch.tensor(square_dims, dtype=torch.float32)
finite = torch.isfinite(square_dims)
if not finite.any():
    raise RuntimeError("No finite per-square dimensions found for this layer.")

filled = square_dims.clone()
filled[~finite] = torch.nanmean(filled[finite])
svg_board, svg_colorbar = LczeroBoard().render_heatmap(filled.reshape(64), normalise="none")
display(HTML(f"<h4>{TARGET_LAYER}: intrinsic dimension per board square</h4>{svg_board}{svg_colorbar}"))

block0/conv2/relu: intrinsic dimension per board square

r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
2026-04-29T10:40:38.450610 image/svg+xml Matplotlib v3.10.0, https://matplotlib.org/