"""
Search utilities.
"""
import chess
import torch
from lczerolens.board import LczeroBoard
from lczerolens.model import ForceValue, LczeroModel
from tensordict import TensorDict
from typing import Callable, Dict, Optional, Protocol, Tuple
[docs]
class Heuristic(Protocol):
"""Heuristic protocol for evaluating chess positions."""
[docs]
def evaluate(
self,
board: LczeroBoard,
) -> TensorDict:
"""
Evaluate a single board.
Parameters
----------
board : LczeroBoard
LczeroBoard instance representing the current board state
Returns
-------
TensorDict
Dictionary with fields value and policy.
"""
...
[docs]
class RandomHeuristic:
"""Simple heuristic for MCTS."""
[docs]
def evaluate(
self,
board: LczeroBoard,
) -> TensorDict:
"""Evaluate a single board.
Parameters
----------
board : LczeroBoard
LczeroBoard instance
Returns
-------
TensorDict
Dictionary with fields value and policy.
"""
n = board.legal_moves.count()
return TensorDict(value=torch.tensor(0.0), policy=torch.full((n,), 1 / n))
[docs]
class MaterialHeuristic:
"""Heuristic that outputs uniform policy and material advantage as value."""
[docs]
default_piece_values = {
chess.PAWN: 1,
chess.KNIGHT: 3,
chess.BISHOP: 3,
chess.ROOK: 5,
chess.QUEEN: 9,
chess.KING: 100,
}
def __init__(
self,
piece_values: Optional[Dict[int, int]] = None,
normalization_constant: float = 0.1,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
):
[docs]
self.piece_values = piece_values or self.default_piece_values
[docs]
self.normalization_constant = normalization_constant
[docs]
self.activation = activation
[docs]
def evaluate(
self,
board: LczeroBoard,
) -> TensorDict:
"""
Compute the label for a given model and input.
"""
us, them = board.turn, not board.turn
relative_value = 0
for piece in range(1, 7):
relative_value += len(board.pieces(piece, us)) * self.piece_values[piece]
relative_value -= len(board.pieces(piece, them)) * self.piece_values[piece]
value = self.activation(torch.tensor(relative_value / self.normalization_constant, dtype=torch.float32))
n = board.legal_moves.count()
policy = torch.full((n,), 1 / n)
return TensorDict(value=value, policy=policy)
[docs]
class ModelHeuristic:
"""Evaluate boards using a neural network model for MCTS."""
def __init__(self, model: LczeroModel):
[docs]
self._model = ForceValue.from_model(model.module)
[docs]
def evaluate(self, board: LczeroBoard) -> TensorDict:
"""
Evaluate a single board using the Lczero model.
Returns TensorDict with 'value' and 'policy'.
"""
td = self._model(board)[0]
legal_indices = board.get_legal_indices()
td["policy"] = td["policy"].gather(0, legal_indices)
return td
[docs]
class Node:
"""Node for MCTS using LczeroBoard."""
def __init__(
self,
board: LczeroBoard,
parent: "Node",
) -> None:
"""Initialize a Node with a given board and parent node."""
[docs]
self.is_terminal: bool = board.is_game_over()
[docs]
self.children: Dict[chess.Move, "Node"] = {}
[docs]
self.legal_moves: Tuple[chess.Move, ...] = tuple(board.legal_moves)
[docs]
self.visits: torch.Tensor = torch.zeros(len(self.legal_moves))
[docs]
self.q_values: torch.Tensor = torch.zeros(len(self.legal_moves))
[docs]
self._value: Optional[torch.Tensor] = None
[docs]
self._policy: Optional[torch.Tensor] = None
[docs]
self._initialized: bool = False
@property
[docs]
def value(self):
return self._value
@property
[docs]
def policy(self):
return self._policy
@property
[docs]
def initialized(self):
return self._initialized
[docs]
def set_evaluation(
self,
td: TensorDict,
) -> None:
"""Set the evaluation for the node.
Parameters
----------
td : TensorDict
TensorDict containing value and policy tensors.
"""
if self._value is not None or self._policy is not None:
raise RuntimeError("Node already initialized.")
self._value = td.get("value")
self._policy = td.get("policy")
self._initialized = True
[docs]
class MCTS:
"""Monte Carlo Tree Search with PUCT formula."""
def __init__(
self,
c_puct: float = 1.0,
n_parallel_rollouts: int = 1,
):
"""Initialize the class."""
[docs]
self.n_parallel_rollouts = n_parallel_rollouts
[docs]
def search_(
self,
root: Node,
heuristic: Heuristic,
iterations: int = 10,
) -> None:
"""Perform MCTS search on the given root node.
Parameters
----------
root : Node
Node instance representing the current board state.
heuristic : Heuristic
Heuristic instance to evaluate board states.
iterations : int
Number of iterations to run the MCTS search.
"""
if root.board.is_game_over():
raise RuntimeError("Game already over.")
if not root.initialized:
root.set_evaluation(heuristic.evaluate(root.board))
for _ in range(iterations):
node = root
done = False
# Selection
while not done:
move = self._select_(node)
# Expansion
if move not in node.children:
done = True
new_board = node.board.copy()
new_board.push(move)
node.children[move] = Node(board=new_board, parent=node)
node = node.children[move]
done = done or node.is_terminal
# Evaluation
value = self._evaluate_(node, heuristic)
# Backpropagation
self._backpropagate_(node, value)
[docs]
def _select_(
self,
node: Node,
) -> chess.Move:
"""Select the move to explore based on the PUCT formula."""
Q = node.q_values.detach() if node.q_values.requires_grad else node.q_values
policy = node.policy.detach() if node.policy.requires_grad else node.policy
# PUCT formula = Q + U
# Q = average value from simulations
# U = exploration bonus encouraging less-visited moves
U = self.c_puct * policy * ((node.visits.sum() + 1) ** 0.5) / (1 + node.visits)
a = torch.argmax(Q + U).item()
node.visits[a] += 1
return node.legal_moves[a]
[docs]
def _evaluate_(
self,
node: Node,
heuristic: Heuristic,
) -> torch.Tensor:
"""Evaluate a single board.
Parameters
----------
node : Node
Node instance representing the current board state.
heuristic : Heuristic
Heuristic instance to evaluate board states.
Returns
-------
value : torch.Tensor
Value tensor for the current node.
"""
if node.initialized:
return node.value
elif node.is_terminal:
outcome = node.board.outcome()
value = torch.Tensor([0.0]) if outcome.winner is None else torch.Tensor([-1.0])
td = TensorDict(value=value, policy=None)
node.set_evaluation(td)
else:
node.set_evaluation(heuristic.evaluate(node.board))
return node.value
[docs]
def _backpropagate_(
self,
node: Node,
value: float,
) -> None:
"""Backpropagate the reward from the leaf node to the root node.
Parameters
----------
node : Node
Node instance representing the leaf node.
value : float
Float value to backpropagate.
"""
while node.parent is not None:
parent = node.parent
value = -value
move = node.board.move_stack[-1]
idx = parent.legal_moves.index(move)
parent.q_values[idx] = (parent.q_values[idx] * parent.visits[idx] + value) / (parent.visits[idx] + 1)
node = parent
[docs]
def render_tree(
root: Node, max_depth: int = 3, save_to: Optional[str] = None, min_visit_percentage: float = 0.0
) -> Optional[str]:
"""
Render the MCTS tree as an SVG.
Parameters
----------
root : Node
Root node of the tree.
max_depth : int, default=3
Maximum depth to render.
save_to : Optional[str], default=None
Path to save the SVG. If None, returns the SVG string.
Returns
-------
Optional[str]
SVG string of the tree, or None if saved to file.
"""
try:
from graphviz import Digraph
except ImportError as e:
raise ImportError(
"graphviz is required to render trees, install it with `pip install lczerolens[viz]`."
) from e
dot = Digraph(comment="MCTS Tree")
dot.attr("graph", rankdir="TB", ranksep="1.5")
dot.attr("node", shape="circle")
dot.node(str(id(root)), label=f"Root\nN={int(root.visits.sum().item())}")
def add_nodes(node: Node, depth: int = 0):
if depth > max_depth:
return
if not node.children:
return
visit_percentages = node.visits / node.visits.sum()
for move, child in node.children.items():
child_index = node.legal_moves.index(move)
if visit_percentages[child_index] < min_visit_percentage:
continue
idx = node.legal_moves.index(move)
n_visits = int(node.visits[idx].item())
label = f"{move}\nN={n_visits}\nV={child.value.item():.3f}"
dot.node(str(id(child)), label=label)
dot.edge(str(id(node)), str(id(child)))
add_nodes(child, depth + 1)
add_nodes(root, 0)
svg_tree = dot.pipe(format="svg").decode("utf-8")
if save_to is not None:
if not save_to.endswith(".svg"):
raise ValueError("Only saving to `.svg` is supported")
with open(save_to, "w") as f:
f.write(svg_tree)
return None
return svg_tree