Source code for lczerolens.lenses.patching
"""Patching lens."""
from typing import Callable
from lczerolens.model import LczeroModel
from lczerolens.lens import Lens
@Lens.register("patching")
class PatchingLens(Lens):
    """Lens that applies a user-supplied patch function to a model's modules.

    The callable given at construction time is invoked once for every
    ``(name, module)`` pair produced by ``self._get_modules(model)``.

    Parameters
    ----------
    patch_fn
        Callable invoked as ``patch_fn(name, module, **kwargs)``. Its return
        value is ignored.
    **kwargs
        Forwarded unchanged to the base ``Lens`` initialiser.

    Examples
    --------
    .. code-block:: python

        def patch_fn(name, module, **kwargs):
            ...

        model = LczeroModel.from_path(model_path)
        lens = PatchingLens(patch_fn)
        board = LczeroBoard()
        results = lens.analyse(board, model=model)
    """

    def __init__(self, patch_fn: Callable, **kwargs) -> None:
        """Store ``patch_fn`` and pass the remaining options to the base lens."""
        # The patch function is stored before the base initialiser runs,
        # so it is already available if the base class needs it during setup.
        self._patch_fn = patch_fn
        super().__init__(**kwargs)

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        """Call the stored patch function on each module of ``model``.

        Parameters
        ----------
        model
            Model whose modules are passed to the patch function.
        **kwargs
            Extra keyword arguments forwarded verbatim to every
            ``patch_fn`` call.

        Returns
        -------
        dict
            Always an empty dictionary; results of ``patch_fn`` are not
            collected.
        """
        # NOTE(review): `_get_modules` is not defined in this class --
        # presumably inherited from `Lens`; confirm which modules it selects.
        for name, module in self._get_modules(model):
            self._patch_fn(name, module, **kwargs)
        return {}