Source code for lczerolens.lens

"""Generic lens class."""

from abc import ABC, abstractmethod
from typing import Dict, Iterable, Generator, Callable, Type, Union, Optional, Any

import torch
import re

from lczerolens.model import LczeroModel
from lczerolens.board import LczeroBoard


class Lens(ABC):
    """Generic lens class for analysing model activations.

    Subclasses implement :meth:`_intervene` and are made discoverable by name
    through the :meth:`register` decorator and :meth:`from_name`.
    """

    # Name under which the concrete lens was registered; set by ``register``.
    _lens_type: str
    # Shared name -> lens class mapping, common to ``Lens`` and all subclasses.
    _registry: Dict[str, Type["Lens"]] = {}

    @classmethod
    def register(cls, name: str) -> Callable:
        """Registers the lens.

        Parameters
        ----------
        name : str
            The name of the lens.

        Returns
        -------
        Callable
            The decorator to register the lens.

        Raises
        ------
        ValueError
            If the lens name is already registered.
        """
        # The check runs when ``register(name)`` is evaluated, i.e. before the
        # decorated class is handed to the decorator.
        if name in cls._registry:
            raise ValueError(f"Lens {name} already registered.")

        def decorator(subclass: Type["Lens"]):
            subclass._lens_type = name
            cls._registry[name] = subclass
            return subclass

        return decorator

    @classmethod
    def from_name(cls, name: str, *args, **kwargs) -> "Lens":
        """Returns the lens from its name.

        Parameters
        ----------
        name : str
            The name of the lens.
        *args, **kwargs
            Forwarded to the constructor of the registered lens class.

        Returns
        -------
        Lens
            The lens instance.

        Raises
        ------
        KeyError
            If the lens name is not found.
        """
        if name not in cls._registry:
            raise KeyError(f"Lens {name} not found.")
        return cls._registry[name](*args, **kwargs)

    def __init__(self, pattern: Optional[str] = None):
        """Initialise the lens.

        Parameters
        ----------
        pattern : Optional[str], default=None
            The pattern to match the modules.
        """
        if pattern is None:
            pattern = r"a^"  # match nothing by default
        # Compile first so an invalid pattern never leaves a half-built lens.
        compiled = re.compile(pattern)
        self._pattern = pattern
        self._reg_exp = compiled

    @property
    def pattern(self) -> str:
        """The pattern to match the modules."""
        return self._pattern

    @pattern.setter
    def pattern(self, pattern: str):
        # Compile before assigning: if the pattern is invalid, ``re.compile``
        # raises and the previous pattern / compiled regex stay in sync.
        compiled = re.compile(pattern)
        self._pattern = pattern
        self._reg_exp = compiled

    def _get_modules(self, model: "LczeroModel") -> Generator[tuple[str, Any], None, None]:
        """Get the modules to intervene on.

        Yields ``(name, module)`` pairs from ``model.named_modules()`` whose
        cleaned name matches the lens pattern. ``re.match`` is used, so the
        pattern is anchored at the start of the name only.
        """
        for name, module in model.named_modules():
            fixed_name = name.lstrip(". ")  # nnsight outputs names with a dot
            if self._reg_exp.match(fixed_name):
                yield fixed_name, module

    def is_compatible(self, model: "LczeroModel") -> bool:
        """Returns whether the lens is compatible with the model.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Returns
        -------
        bool
            Whether the lens is compatible with the model.
        """
        return isinstance(model, LczeroModel)

    def _ensure_compatible(self, model: "LczeroModel"):
        """Ensure the lens is compatible with the model.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Raises
        ------
        ValueError
            If the lens is not compatible with the model.
        """
        if not self.is_compatible(model):
            # ``_lens_type`` only exists on registered lenses; fall back to the
            # class name so unregistered subclasses still get a ValueError
            # rather than an AttributeError while building the message.
            lens_type = getattr(self, "_lens_type", type(self).__name__)
            raise ValueError(f"Lens {lens_type} is not compatible with model of type {type(model)}.")

    def prepare(self, model: "LczeroModel", **kwargs) -> "LczeroModel":
        """Prepare the model for the lens.

        The base implementation returns the model unchanged.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Returns
        -------
        LczeroModel
            The prepared model.
        """
        return model

    @abstractmethod
    def _intervene(self, model: "LczeroModel", **kwargs) -> dict:
        """Intervene on the model.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Returns
        -------
        dict
            The intervention results.
        """
        pass

    def _trace(
        self,
        model: "LczeroModel",
        *inputs: Union["LczeroBoard", torch.Tensor],
        model_kwargs: dict,
        intervention_kwargs: dict,
    ) -> dict:
        """Trace the model and intervene on it.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        inputs : Union[LczeroBoard, torch.Tensor]
            The inputs.
        model_kwargs : dict
            The model kwargs.
        intervention_kwargs : dict
            The intervention kwargs.

        Returns
        -------
        dict
            The intervention results.
        """
        with model.trace(*inputs, **model_kwargs):
            return self._intervene(model, **intervention_kwargs)

    def analyse(
        self,
        model: "LczeroModel",
        *inputs: Union["LczeroBoard", torch.Tensor],
        **kwargs,
    ) -> dict:
        """Analyse the input.

        The full ``kwargs`` (including ``model_kwargs`` when present) are
        forwarded both to :meth:`prepare` and to the intervention.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        inputs : Union[LczeroBoard, torch.Tensor]
            The inputs.
        model_kwargs : dict, optional
            Keyword arguments passed to ``model.trace``; defaults to ``{}``.

        Returns
        -------
        dict
            The analysis results.

        Raises
        ------
        ValueError
            If the lens is not compatible with the model.
        """
        self._ensure_compatible(model)
        model_kwargs = kwargs.get("model_kwargs", {})
        prepared_model = self.prepare(model, **kwargs)
        return self._trace(prepared_model, *inputs, model_kwargs=model_kwargs, intervention_kwargs=kwargs)

    def analyse_batched(
        self,
        model: "LczeroModel",
        iter_inputs: Iterable[tuple[Iterable[Union["LczeroBoard", torch.Tensor]], dict]],
        **kwargs,
    ) -> Generator[dict, None, None]:
        """Analyse batches of inputs.

        This is a generator: the compatibility check and model preparation
        run when the first result is requested, not when the method is called.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        iter_inputs : Iterable[Tuple[Iterable[Union[LczeroBoard, torch.Tensor]], dict]]
            The iterator over ``(inputs, intervention_kwargs)`` pairs. ``inputs``
            is unpacked as the positional inputs of the trace and the dict holds
            intervention kwargs that apply to that batch only.
        model_kwargs : dict, optional
            Keyword arguments passed to ``model.trace``; defaults to ``{}``.

        Returns
        -------
        Generator[dict, None, None]
            The iterator over the statistics.

        Raises
        ------
        ValueError
            If the lens is not compatible with the model.
        """
        self._ensure_compatible(model)
        model_kwargs = kwargs.get("model_kwargs", {})
        prepared_model = self.prepare(model, **kwargs)
        for inputs, dynamic_intervention_kwargs in iter_inputs:
            # Merge into a fresh dict so that kwargs supplied for one batch do
            # not carry over to the following batches.
            batch_kwargs = {**kwargs, **dynamic_intervention_kwargs}
            yield self._trace(
                prepared_model,
                *inputs,
                model_kwargs=model_kwargs,
                intervention_kwargs=batch_kwargs,
            )