"""Class for concept-based XAI methods."""
from abc import ABC, abstractmethod
from typing import Any, Optional, Dict
import torch
import chess
from lczerolens.board import LczeroBoard
from lczerolens.model import LczeroModel, PolicyFlow
class Concept(ABC):
    """Abstract interface that every concept has to implement.

    A concrete concept provides ``compute_label`` (called with a board) and
    the two static methods ``compute_metrics`` and ``get_dataset_feature``.
    """

    @abstractmethod
    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """Return the concept label for ``board``."""

    @staticmethod
    @abstractmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """Return evaluation metrics comparing ``predictions`` with ``labels``."""

    @staticmethod
    @abstractmethod
    def get_dataset_feature():
        """Return the dataset feature describing this concept's labels."""
class BinaryConcept(Concept):
    """Base class for concepts whose label takes one of two classes."""

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """Score binary ``predictions`` against the reference ``labels``.

        Returns
        -------
        dict
            ``accuracy``, ``precision``, ``recall`` and ``f1`` as computed by
            the corresponding ``sklearn.metrics`` functions.

        Raises
        ------
        ImportError
            If scikit-learn is not installed.
        """
        try:
            from sklearn import metrics as sk_metrics
        except ImportError as e:
            raise ImportError("scikit-learn is required to compute metrics.") from e

        # Every scorer is called as scorer(y_true, y_pred), in this order.
        scorers = {
            "accuracy": sk_metrics.accuracy_score,
            "precision": sk_metrics.precision_score,
            "recall": sk_metrics.recall_score,
            "f1": sk_metrics.f1_score,
        }
        return {name: scorer(labels, predictions) for name, scorer in scorers.items()}

    @staticmethod
    def get_dataset_feature():
        """Return a two-class ``datasets.ClassLabel`` feature.

        Raises
        ------
        ImportError
            If the ``datasets`` package is not installed.
        """
        try:
            from datasets import ClassLabel as class_label_cls
        except ImportError as e:
            raise ImportError("datasets is required to get the dataset features.") from e

        return class_label_cls(num_classes=2)
class NullConcept(BinaryConcept):
    """Binary concept whose label is the constant ``0`` for every board."""

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """Return ``0`` regardless of ``board`` (the argument is not inspected)."""
        return 0
class OrBinaryConcept(BinaryConcept):
    """Binary concept that is the logical OR of several binary concepts."""

    def __init__(self, *concepts: BinaryConcept):
        """Store the concepts to combine.

        Parameters
        ----------
        *concepts : BinaryConcept
            The concepts whose labels are OR-ed together.

        Raises
        ------
        ValueError
            If any argument is not a ``BinaryConcept`` instance.
        """
        for concept in concepts:
            if not isinstance(concept, BinaryConcept):
                raise ValueError(f"{concept} is not a BinaryConcept")
        self.concepts = concepts

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """Return ``1`` if at least one wrapped concept is truthy on ``board``, else ``0``.

        Evaluation short-circuits on the first truthy label. With no wrapped
        concepts the result is ``0``.
        """
        # Coerce to int so the label is 0/1 like the other binary concepts in
        # this module, rather than a bare bool.
        return int(any(concept.compute_label(board) for concept in self.concepts))
class AndBinaryConcept(BinaryConcept):
    """Binary concept that is the logical AND of several binary concepts."""

    def __init__(self, *concepts: BinaryConcept):
        """Store the concepts to combine.

        Parameters
        ----------
        *concepts : BinaryConcept
            The concepts whose labels are AND-ed together.

        Raises
        ------
        ValueError
            If any argument is not a ``BinaryConcept`` instance.
        """
        for concept in concepts:
            if not isinstance(concept, BinaryConcept):
                raise ValueError(f"{concept} is not a BinaryConcept")
        self.concepts = concepts

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """Return ``1`` if every wrapped concept is truthy on ``board``, else ``0``.

        Evaluation short-circuits on the first falsy label. With no wrapped
        concepts the result is ``1`` (``all`` of an empty sequence is true).
        """
        # Coerce to int so the label is 0/1 like the other binary concepts in
        # this module, rather than a bare bool.
        return int(all(concept.compute_label(board) for concept in self.concepts))
class MulticlassConcept(Concept):
    """Base class for concepts whose label is one of several integer classes."""

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """Score multiclass ``predictions`` against the reference ``labels``.

        Returns
        -------
        dict
            ``accuracy`` plus ``precision``, ``recall`` and ``f1``, the latter
            three computed with ``average="weighted"``.

        Raises
        ------
        ImportError
            If scikit-learn is not installed.
        """
        try:
            from sklearn import metrics
        except ImportError as e:
            raise ImportError("scikit-learn is required to compute metrics.") from e

        return {
            "accuracy": metrics.accuracy_score(labels, predictions),
            "precision": metrics.precision_score(labels, predictions, average="weighted"),
            "recall": metrics.recall_score(labels, predictions, average="weighted"),
            "f1": metrics.f1_score(labels, predictions, average="weighted"),
        }

    @staticmethod
    def get_dataset_feature():
        """Return a scalar ``datasets.Value("int32")`` feature.

        Raises
        ------
        ImportError
            If the ``datasets`` package is not installed.
        """
        try:
            from datasets import Value
        except ImportError as e:
            raise ImportError("datasets is required to get the dataset features.") from e

        # Return the bare feature, as the binary and continuous concepts do.
        # The previous one-element tuple (stray trailing comma) would be read
        # as a sequence feature, not a single integer label.
        return Value("int32")
class ContinuousConcept(Concept):
    """Base class for concepts whose label is a real-valued quantity."""

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """Score regression ``predictions`` against the reference ``labels``.

        Returns
        -------
        dict
            ``rmse``, ``mae`` and ``r2``.

        Raises
        ------
        ImportError
            If scikit-learn is not installed.
        """
        try:
            from sklearn import metrics
        except ImportError as e:
            raise ImportError("scikit-learn is required to compute metrics.") from e

        # ``root_mean_squared_error`` was only added in scikit-learn 1.4. On
        # older releases fall back to the equivalent
        # ``mean_squared_error(..., squared=False)`` instead of failing with
        # an AttributeError.
        rmse_fn = getattr(metrics, "root_mean_squared_error", None)
        if rmse_fn is not None:
            rmse = rmse_fn(labels, predictions)
        else:
            rmse = metrics.mean_squared_error(labels, predictions, squared=False)

        return {
            "rmse": rmse,
            "mae": metrics.mean_absolute_error(labels, predictions),
            "r2": metrics.r2_score(labels, predictions),
        }

    @staticmethod
    def get_dataset_feature():
        """Return a scalar ``datasets.Value("float32")`` feature.

        Raises
        ------
        ImportError
            If the ``datasets`` package is not installed.
        """
        try:
            from datasets import Value
        except ImportError as e:
            raise ImportError("datasets is required to get the dataset features.") from e

        return Value("float32")
class HasPiece(BinaryConcept):
    """Binary concept: is at least one piece of a given kind on the board?"""

    def __init__(
        self,
        piece: str,
        relative: bool = True,
    ):
        """Configure the piece to look for.

        Parameters
        ----------
        piece : str
            Piece symbol, parsed with ``chess.Piece.from_symbol``.
        relative : bool
            When true, the piece colour is flipped whenever ``board.turn`` is
            falsy (see ``compute_label``).
        """
        self.piece = chess.Piece.from_symbol(piece)
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return ``1`` if the board holds at least one matching piece, else ``0``."""
        target_color = self.piece.color
        # In relative mode the colour is inverted when ``board.turn`` is
        # falsy; otherwise the colour of the configured piece is used as is.
        if self.relative and not board.turn:
            target_color = not target_color

        occupied = board.pieces(self.piece.piece_type, target_color)
        return int(len(occupied) > 0)
# Material concepts
class HasMaterialAdvantage(BinaryConcept):
    """Binary concept: does one side have strictly more material than the other?

    Attributes
    ----------
    piece_values : Dict[int, int]
        Default value of each piece type, keyed by the ``chess`` piece-type
        constant.
    """

    piece_values = {
        chess.PAWN: 1,
        chess.KNIGHT: 3,
        chess.BISHOP: 3,
        chess.ROOK: 5,
        chess.QUEEN: 9,
        chess.KING: 0,
    }

    def __init__(
        self,
        relative: bool = True,
    ):
        """Configure the point of view.

        Parameters
        ----------
        relative : bool
            When true, "us" is the side given by ``board.turn``; otherwise
            "us" is always ``chess.WHITE``.
        """
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
        piece_values: Optional[Dict[int, int]] = None,
    ) -> int:
        """Return ``1`` if "us" has strictly more material than "them", else ``0``.

        Parameters
        ----------
        board : LczeroBoard
            Position to evaluate.
        piece_values : Optional[Dict[int, int]]
            Per-piece-type values; defaults to the class-level ``piece_values``.
            A value is looked up for every piece type.
        """
        values = self.piece_values if piece_values is None else piece_values

        if self.relative:
            us = board.turn
            them = not board.turn
        else:
            us = chess.WHITE
            them = chess.BLACK

        def material(color):
            # Number of pieces of each type times that type's value.
            return sum(
                len(board.pieces(piece_type, color)) * values[piece_type]
                for piece_type in chess.PIECE_TYPES
            )

        return int(material(us) > material(them))
# Move concepts
class BestLegalMove(MulticlassConcept):
    """Multiclass concept: encoded index of the legal move with the highest policy."""

    def __init__(
        self,
        model: LczeroModel,
    ):
        """Build the policy flow from ``model.module``."""
        self.policy_flow = PolicyFlow.from_model(model.module)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return the encoded index of the legal move with the highest policy value."""
        raw_policy = self.policy_flow(board)["policy"].squeeze(0)
        probabilities = torch.softmax(raw_policy, dim=-1)

        # Encode each legal move, then restrict the policy to those entries.
        candidate_indices = [
            LczeroBoard.encode_move(candidate, board.turn) for candidate in board.legal_moves
        ]
        best_position = probabilities[candidate_indices].argmax().item()
        return candidate_indices[best_position]
class PieceBestLegalMove(BinaryConcept):
    """Binary concept: is the highest-policy legal move played by a given piece?"""

    def __init__(
        self,
        model: LczeroModel,
        piece: str,
    ):
        """Build the policy flow from ``model.module`` and parse the piece symbol."""
        self.policy_flow = PolicyFlow.from_model(model.module)
        self.piece = chess.Piece.from_symbol(piece)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return ``1`` if the best legal move starts from a square holding ``self.piece``."""
        raw_policy = self.policy_flow(board)["policy"].squeeze(0)
        probabilities = torch.softmax(raw_policy, dim=-1)

        # Keep the moves in a list so a position in the restricted policy
        # maps back to the move it was computed for.
        candidates = list(board.legal_moves)
        candidate_indices = [
            LczeroBoard.encode_move(candidate, board.turn) for candidate in candidates
        ]
        best_position = probabilities[candidate_indices].argmax().item()
        top_move = candidates[best_position]

        moving_piece = board.piece_at(top_move.from_square)
        return int(moving_piece == self.piece)
# Threat concepts
class HasThreat(BinaryConcept):
    """Binary concept: is some piece of a given kind attacked by the other colour?"""

    def __init__(
        self,
        piece: str,
        relative: bool = True,
    ):
        """Configure the piece to check.

        Parameters
        ----------
        piece : str
            Piece symbol, parsed with ``chess.Piece.from_symbol``.
        relative : bool
            When true, the piece colour is flipped whenever ``board.turn`` is
            falsy (see ``compute_label``).
        """
        self.piece = chess.Piece.from_symbol(piece)
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return ``1`` if any matching piece stands on a square attacked by the opposite colour."""
        target_color = self.piece.color
        # In relative mode the colour is inverted when ``board.turn`` is falsy.
        if self.relative and not board.turn:
            target_color = not target_color

        attacker = not target_color
        occupied = board.pieces(self.piece.piece_type, target_color)
        # Stops at the first attacked square.
        return int(any(board.is_attacked_by(attacker, square) for square in occupied))
class HasMateThreat(BinaryConcept):
    """Binary concept: does some legal move lead directly to checkmate?"""

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return ``1`` if at least one legal move leaves the position in checkmate, else ``0``.

        Each legal move is pushed onto ``board``, the resulting position is
        tested with ``is_checkmate``, and the move is popped again, so the
        board is back in its original state when this method returns.
        """
        for move in board.legal_moves:
            board.push(move)
            try:
                if board.is_checkmate():
                    return 1
            finally:
                # Always undo the trial move, including on the early return
                # and if ``is_checkmate`` raises, so the caller's board is
                # never left with an extra move on it.
                board.pop()
        return 0