Source code for lczerolens.concepts.move

"""All concepts related to move."""

import chess
import torch

from lczerolens.board import LczeroBoard
from lczerolens.model import LczeroModel, PolicyFlow
from lczerolens.concept import BinaryConcept, MulticlassConcept


class BestLegalMove(MulticlassConcept):
    """Multiclass concept labelling a board with the model's best legal move.

    The label is the policy index, as returned by
    ``LczeroBoard.encode_move``, of the legal move that receives the highest
    policy probability from the wrapped model.
    """

    def __init__(
        self,
        model: LczeroModel,
    ) -> None:
        """Initialize the concept.

        Args:
            model: Model whose policy output is queried through a
                ``PolicyFlow`` wrapper.
        """
        self.policy_flow = PolicyFlow(model)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Compute the policy index of the model's preferred legal move.

        Args:
            board: Position to evaluate.

        Returns:
            The encoded index of the legal move with the highest policy
            probability.

        Raises:
            ValueError: If ``board`` has no legal moves (for example a
                checkmate or stalemate position), since no best move exists.
        """
        # Encode the legal moves first so terminal positions fail fast with a
        # clear message instead of an opaque error from ``argmax`` on an
        # empty tensor after a wasted forward pass.
        legal_move_indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves]
        if not legal_move_indices:
            raise ValueError("Cannot compute the best legal move: the board has no legal moves.")

        (policy,) = self.policy_flow(board)
        # NOTE(review): squeeze(0) presumably drops a leading batch dimension
        # of size 1 -- confirm against PolicyFlow's output shape.
        policy = torch.softmax(policy.squeeze(0), dim=-1)

        # argmax is taken over the legal moves only, so the result is a
        # position in ``legal_move_indices`` and must be mapped back.
        sub_index = policy[legal_move_indices].argmax().item()
        return legal_move_indices[sub_index]
class PieceBestLegalMove(BinaryConcept):
    """Binary concept: is the model's best legal move made by a given piece?

    The label is 1 when the piece standing on the origin square of the
    highest-probability legal move equals the configured piece, and 0
    otherwise.
    """

    def __init__(
        self,
        model: LczeroModel,
        piece: str,
    ) -> None:
        """Initialize the concept.

        Args:
            model: Model whose policy output is queried through a
                ``PolicyFlow`` wrapper.
            piece: Piece symbol parsed with ``chess.Piece.from_symbol``.
                NOTE(review): in python-chess the symbol's case selects the
                colour (e.g. ``"N"`` vs ``"n"``), so the concept is
                presumably colour-specific -- confirm this is intended.
        """
        self.policy_flow = PolicyFlow(model)
        self.piece = chess.Piece.from_symbol(piece)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Compute whether the model's best legal move is made by ``self.piece``.

        Args:
            board: Position to evaluate.

        Returns:
            1 if the piece on the best legal move's origin square equals
            ``self.piece``, otherwise 0.

        Raises:
            ValueError: If ``board`` has no legal moves (for example a
                checkmate or stalemate position), since no best move exists.
        """
        # Materialise the legal moves once so the index chosen by argmax can
        # be mapped back to the corresponding move object.
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Fail fast with a clear message instead of an opaque error from
            # ``argmax`` on an empty tensor after a wasted forward pass.
            raise ValueError("Cannot compute the best legal move: the board has no legal moves.")
        legal_move_indices = [LczeroBoard.encode_move(move, board.turn) for move in legal_moves]

        (policy,) = self.policy_flow(board)
        # NOTE(review): squeeze(0) presumably drops a leading batch dimension
        # of size 1 -- confirm against PolicyFlow's output shape.
        policy = torch.softmax(policy.squeeze(0), dim=-1)

        sub_index = policy[legal_move_indices].argmax().item()
        best_legal_move = legal_moves[sub_index]
        return 1 if board.piece_at(best_legal_move.from_square) == self.piece else 0