Source code for lczerolens.lenses.composite
"""Composite lens for XAI."""
from typing import List, Dict, Union, Any
from lczerolens.lens import Lens
from lczerolens.model import LczeroModel
[docs]
class CompositeLens(Lens):
    """Lens that bundles several lenses and runs them as one.

    The wrapped lenses are kept in a name -> lens mapping. A dict given to
    the constructor is used directly; any other collection of lenses is
    keyed ``lens_0``, ``lens_1``, ... in iteration order.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        lens = CompositeLens([ActivationLens(), GradientLens()])
        board = LczeroBoard()
        results = lens.analyse(board, model=model)
    """

    def __init__(
        self,
        lenses: Union[List[Lens], Dict[str, Lens]],
        merge_results: bool = True,
    ):
        """Store the wrapped lenses and the result-merging policy.

        Parameters
        ----------
        lenses
            Either a mapping of names to lenses, or a list of lenses that
            will be auto-named ``lens_<index>``.
        merge_results
            If true, ``_intervene`` flattens the per-lens outputs into a
            single dict; otherwise outputs stay grouped by lens name.
        """
        if isinstance(lenses, dict):
            # The caller's dict is stored as-is (no copy is made).
            lens_map = lenses
        else:
            lens_map = {}
            for index, sub_lens in enumerate(lenses):
                lens_map[f"lens_{index}"] = sub_lens
        self._lens_map = lens_map
        self.merge_results = merge_results

    def is_compatible(self, model: LczeroModel) -> bool:
        """Return ``True`` only when every wrapped lens accepts ``model``."""
        for sub_lens in self._lens_map.values():
            # Stop at the first lens that rejects the model.
            if not sub_lens.is_compatible(model):
                return False
        return True

    def prepare(self, model: LczeroModel, **kwargs) -> LczeroModel:
        """Run each wrapped lens's ``prepare`` in turn.

        The model returned by one lens is the input to the next; the model
        returned by the last lens is the result. ``kwargs`` are forwarded
        unchanged to every lens.
        """
        prepared = model
        for sub_lens in self._lens_map.values():
            prepared = sub_lens.prepare(prepared, **kwargs)
        return prepared

    def _intervene(self, model: LczeroModel, **kwargs) -> Dict[str, Any]:
        """Collect the ``_intervene`` output of every wrapped lens.

        Returns
        -------
        Dict[str, Any]
            With ``merge_results`` false: ``{lens_name: lens_output}``.
            With ``merge_results`` true: the items of every lens output
            flattened into one dict. If two lenses emit the same key, the
            value from the lens that comes later in the mapping wins.
        """
        # Run every lens first so that merging only happens once all
        # per-lens outputs are available.
        per_lens = {}
        for name, sub_lens in self._lens_map.items():
            per_lens[name] = sub_lens._intervene(model, **kwargs)

        if not self.merge_results:
            return per_lens

        merged = {}
        for lens_output in per_lens.values():
            for key, value in lens_output.items():
                merged[key] = value
        return merged