# Source code for lczerolens.lenses.sae.buffer
"""Buffer that computes model activations and serves them in shuffled batches."""
from typing import Any, Optional, Callable
from dataclasses import dataclass
import torch
from datasets import Dataset
from torch.utils.data import DataLoader, TensorDataset
from lczerolens.model import LczeroModel
@dataclass
class ActivationBuffer:
    """Compute model activations on the fly and serve them in shuffled batches.

    Batches of ``compute_batch_size`` items are drawn from ``self.dataset`` and
    passed to ``compute_fn(batch, self.model)``. The resulting tensors are moved
    to CPU, and up to ``n_batches_in_buffer`` of them are concatenated along
    dim 0 into a buffer. The buffer is served in shuffled batches of
    ``train_batch_size`` rows. When a buffer does not split evenly, its short
    trailing batch is carried over and merged into the next buffer. Every
    returned batch therefore has ``train_batch_size`` rows, except possibly
    the final one.

    Calling ``iter()`` restarts from the beginning of the dataset. ``next()``
    may also be called directly after construction.

    NOTE(review): ``self.dataset`` and ``self.model`` are read by this class
    but are not declared as fields in this definition. They must be provided
    some other way (e.g. by a subclass); otherwise ``__post_init__`` raises
    ``AttributeError``. Confirm whether they were meant to be dataclass fields.
    """

    # Called as ``compute_fn(batch, model)`` on each dataset batch; the tensor
    # it returns is moved to CPU and buffered.
    compute_fn: Callable[[Any, LczeroModel], torch.Tensor]
    # Number of dataset batches whose activations are accumulated per buffer.
    n_batches_in_buffer: int = 10
    # Batch size used when iterating the dataset to compute activations.
    compute_batch_size: int = 64
    # Number of activation rows per batch returned by ``__next__``.
    train_batch_size: int = 2048
    # Extra keyword arguments forwarded to the dataset ``DataLoader``.
    dataloader_kwargs: Optional[dict] = None
    # Optional object exposing ``info(str)`` (e.g. a ``logging.Logger``).
    logger: Optional[Any] = None

    def __post_init__(self):
        """Normalise options and prepare the dataset iterator."""
        if self.dataloader_kwargs is None:
            self.dataloader_kwargs = {}
        self._buffer: list = []
        self._remainder: Optional[torch.Tensor] = None
        # Start with an exhausted activations iterator so that calling
        # ``next()`` before ``iter()`` simply triggers the first buffer fill.
        self._activations_it = iter(())
        self._make_dataloader_it()

    def _make_dataloader_it(self):
        """(Re)start iteration over the dataset in ``compute_batch_size`` batches."""
        self._dataloader_it = iter(
            DataLoader(self.dataset, batch_size=self.compute_batch_size, **self.dataloader_kwargs)
        )

    @torch.no_grad()
    def _fill_buffer(self):
        """Compute activations for up to ``n_batches_in_buffer`` dataset batches.

        Previous buffer contents are discarded. Each result is moved to CPU
        and appended to ``self._buffer``. Runs with gradients disabled.

        Raises:
            StopIteration: if the dataset iterator yields no batch at all.
        """
        if self.logger is not None:
            self.logger.info("Computing activations...")
        self._buffer = []
        while len(self._buffer) < self.n_batches_in_buffer:
            try:
                next_batch = next(self._dataloader_it)
            except StopIteration:
                break
            activations = self.compute_fn(next_batch, self.model)
            self._buffer.append(activations.to("cpu"))
        if not self._buffer:
            raise StopIteration

    def _make_activations_it(self):
        """Build a shuffled iterator of ``train_batch_size`` batches over the buffer.

        Any carried-over remainder is merged into the buffer first.
        """
        if self._remainder is not None:
            self._buffer.append(self._remainder)
            self._remainder = None
        activations_ds = TensorDataset(torch.cat(self._buffer, dim=0))
        # ``torch.cat`` made a new tensor; drop the per-batch chunks so the
        # buffered activations are not held in memory twice.
        self._buffer = []
        if self.logger is not None:
            self.logger.info(f"Activations dataset of size {len(activations_ds)}")
        self._activations_it = iter(
            DataLoader(
                activations_ds,
                batch_size=self.train_batch_size,
                shuffle=True,
            )
        )

    def __iter__(self):
        """Restart from the beginning of the dataset and return ``self``."""
        # Clear any remainder (e.g. one stranded by an exception in a previous
        # pass) *before* building the first buffer, so it cannot leak into it.
        self._remainder = None
        self._make_dataloader_it()
        try:
            self._fill_buffer()
        except StopIteration:
            # Empty dataset: behave as an empty iterator instead of letting
            # StopIteration escape from ``iter()``.
            self._activations_it = iter(())
        else:
            self._make_activations_it()
        return self

    def __next__(self) -> torch.Tensor:
        """Return the next batch of activations.

        Returns:
            A tensor with ``train_batch_size`` rows. Only the final batch, once
            the dataset is exhausted, may have fewer.

        Raises:
            StopIteration: when the dataset and all buffered activations are
                exhausted.
        """
        while True:
            try:
                activations = next(self._activations_it)[0]
            except StopIteration:
                pass  # Current buffer fully served; refill below.
            else:
                if activations.shape[0] >= self.train_batch_size:
                    return activations
                # Short trailing batch: carry it over into the next buffer
                # instead of returning it early.
                self._remainder = activations
            try:
                self._fill_buffer()
            except StopIteration:
                # Dataset exhausted: flush any carried-over rows, then stop.
                if self._remainder is None:
                    raise
                activations, self._remainder = self._remainder, None
                return activations
            # Merges the carried-over remainder (if any) into the new buffer.
            self._make_activations_it()