lczerolens.model
Class for wrapping the LCZero models.
Attributes
MISSING_HF_ERROR
Classes
LczeroModel | Class for wrapping the LCZero models.
ForceValue | Class for forcing and isolating the value flow.
Flow | Base class for isolating a flow.
PolicyFlow | Class for isolating the policy flow.
ValueFlow | Class for isolating the value flow.
Class for isolating the WDL flow.
Class for isolating the MLH flow.
Module Contents
- lczerolens.model.MISSING_HF_ERROR = 'huggingface_hub is required to push or load the model from the Hugging Face Hub. Install it...
- class lczerolens.model.LczeroModel(module, out_keys, **kwargs)
Bases:
tensordict.nn.TensorDictModule
Class for wrapping the LCZero models.
- Parameters:
module (torch.nn.Module)
out_keys (List[str])
- prepare_boards(*boards, input_encoding=InputEncoding.INPUT_CLASSICAL_112_PLANE)
Prepares the boards for the model.
- Parameters:
*boards (LczeroBoard) – The boards to prepare.
input_encoding (InputEncoding, optional) – The encoding of the boards.
- Returns:
The prepared boards.
- Return type:
torch.Tensor
- forward(inputs, prepare_kwargs=None, **kwargs)
- Parameters:
inputs (Union[TensorDict, Iterable[LczeroBoard], torch.Tensor]) – The inputs to the model.
prepare_kwargs (Optional[Dict[str, Any]], optional) – Keyword arguments to pass to the prepare_boards method, by default None
**kwargs (Any) – Additional keyword arguments to pass to the super().forward method.
- Returns:
The output of the model.
- Return type:
TensorDict
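The parameter list above suggests that forward accepts either raw boards or already-encoded inputs, and that prepare_kwargs is only relevant when boards still need encoding. A minimal sketch of that routing follows. It uses stand-ins only: `Prepared`, `fake_prepare_boards`, and `normalize_inputs` are hypothetical names that do not exist in lczerolens, and the library's actual dispatch may differ.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class Prepared:
    """Hypothetical stand-in for an already-encoded input (tensor or TensorDict)."""

    def __init__(self, planes):
        self.planes = list(planes)


def fake_prepare_boards(*boards, input_encoding="classical"):
    """Hypothetical stand-in for prepare_boards: tags each board with the encoding."""
    return Prepared(f"{input_encoding}:{board}" for board in boards)


def normalize_inputs(
    inputs: Any,
    prepare: Callable[..., Any],
    passthrough_types: Tuple[type, ...],
    prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Encode raw boards; pass already-encoded inputs through unchanged."""
    if isinstance(inputs, passthrough_types):
        return inputs
    # prepare_kwargs=None means "use the defaults of the prepare function".
    return prepare(*inputs, **(prepare_kwargs or {}))


# Raw boards are encoded, and prepare_kwargs reaches the prepare function.
encoded = normalize_inputs(
    ["board_a", "board_b"],
    fake_prepare_boards,
    (Prepared,),
    prepare_kwargs={"input_encoding": "custom"},
)

# An already-encoded input is returned as the same object.
already = Prepared(["x"])
untouched = normalize_inputs(already, fake_prepare_boards, (Prepared,))
```

Treating None as an empty dictionary keeps the default encoding in effect unless the caller overrides it explicitly.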
- _call_module(tensors, **kwargs)
- Parameters:
tensors (Sequence[torch.Tensor])
kwargs (Any)
- Return type:
Sequence[torch.Tensor]
- classmethod from_model(model, **kwargs)
Creates a wrapper from a model.
- Parameters:
model (nn.Module) – The model to wrap.
**kwargs (Any) – Additional keyword arguments to pass to the super().__init__ method.
- Returns:
The wrapped model instance
- Return type:
LczeroModel
- classmethod from_path(model_path, **kwargs)
Creates a wrapper from a model path.
- Parameters:
model_path (str) – Path to the model file (.onnx or .pt)
- Returns:
The wrapped model instance
- Return type:
LczeroModel
- Raises:
NotImplementedError – If the model file extension is not supported
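The documented behaviour of from_path (accept .onnx or .pt, raise NotImplementedError otherwise) can be illustrated with a small extension dispatcher. This is a sketch under assumptions, not the library's code: `pick_loader` is a hypothetical helper, and whether lczerolens normalises the case of the extension is not stated in this page.

```python
from pathlib import Path


def pick_loader(model_path: str) -> str:
    """Return the name of the classmethod that would handle this file."""
    # Assumption: the extension is compared case-insensitively.
    suffix = Path(model_path).suffix.lower()
    if suffix == ".onnx":
        return "from_onnx_path"
    if suffix == ".pt":
        return "from_torch_path"
    raise NotImplementedError(f"Unsupported model file extension: {suffix!r}")


onnx_choice = pick_loader("weights/network.onnx")
torch_choice = pick_loader("weights/network.pt")

try:
    pick_loader("weights/network.bin")
    unsupported = None
except NotImplementedError as exc:
    unsupported = str(exc)
```

Returning the method name rather than calling it keeps the sketch free of any dependency on real model files.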
- classmethod from_onnx_path(onnx_model_path, check=True, **kwargs)
Builds a model from an ONNX file path.
- Parameters:
onnx_model_path (str) – Path to the ONNX model file
check (bool, optional) – Whether to perform shape inference check, by default True
- Returns:
The wrapped model instance
- Return type:
LczeroModel
- Raises:
FileNotFoundError – If the model file does not exist
ValueError – If the model could not be loaded
- classmethod from_torch_path(torch_model_path, weights_only=False, **kwargs)
Builds a model from a PyTorch file path.
- Parameters:
torch_model_path (str) – Path to the PyTorch model file
weights_only (bool)
- Returns:
The wrapped model instance
- Return type:
LczeroModel
- Raises:
FileNotFoundError – If the model file does not exist
ValueError – If the model could not be loaded or is not a valid model type
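Both from_onnx_path and from_torch_path document the same error contract: FileNotFoundError for a missing file, ValueError when the file exists but cannot be loaded. One way such a contract might be implemented is sketched below. `load_checked` and `broken_loader` are hypothetical names used for illustration, and the real methods may perform their checks differently.

```python
import os
import tempfile


def load_checked(path, loader):
    """Call loader(path), mapping failures onto the documented exceptions."""
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    try:
        return loader(path)
    except Exception as exc:
        # Keep the original error attached as the cause for debugging.
        raise ValueError(f"Could not load model from {path}") from exc


def broken_loader(path):
    """Hypothetical loader that always fails, simulating a corrupt file."""
    raise RuntimeError("corrupt file")


outcomes = {}
with tempfile.TemporaryDirectory() as tmp:
    missing = os.path.join(tmp, "missing.pt")
    present = os.path.join(tmp, "present.pt")
    with open(present, "wb") as handle:
        handle.write(b"not a real model")

    for label, path in (("missing", missing), ("corrupt", present)):
        try:
            load_checked(path, broken_loader)
        except (FileNotFoundError, ValueError) as exc:
            outcomes[label] = type(exc).__name__
```

Callers can then distinguish a wrong path from a damaged or incompatible file by catching the two exception types separately.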
- push_to_hf(repo_id, create_if_not_exists=True, create_kwargs=None, path_in_repo='model.pt', **kwargs)
Pushes the model to the Hugging Face Hub.
- Parameters:
repo_id (str) – The repository id to push the model to.
create_if_not_exists (bool, optional) – Whether to create the repository if it does not exist, by default True
create_kwargs (Optional[Dict[str, Any]], optional) – Additional keyword arguments to pass to the create_repo method.
path_in_repo (str, optional) – The path in the repository to save the model to.
**kwargs (Any) – Additional keyword arguments to pass to the upload_file method.
- Raises:
ImportError – If the huggingface_hub library is not installed.
- classmethod from_hf(repo_id, filename='model.pt', hf_hub_kwargs=None, **kwargs)
Loads a model from the Hugging Face Hub.
- Parameters:
repo_id (str) – The repository id to load the model from.
filename (str) – The filename of the model to load.
hf_hub_kwargs (Optional[Dict[str, Any]], optional) – Additional keyword arguments to pass to the hf_hub_download method.
**kwargs (Any) – Additional keyword arguments to pass to the from_path method.
- Returns:
The loaded model instance
- Return type:
LczeroModel
- Raises:
ImportError – If the huggingface_hub library is not installed.
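push_to_hf and from_hf both raise ImportError when huggingface_hub is absent, which points to an optional-dependency guard. A minimal sketch of that pattern using only the standard library follows. `require_package` and `HYPOTHETICAL_MISSING_MESSAGE` are illustrative names; the actual MISSING_HF_ERROR text is truncated in this page, so a placeholder message is used instead.

```python
import importlib.util

# Placeholder text: the real MISSING_HF_ERROR message is truncated in this page.
HYPOTHETICAL_MISSING_MESSAGE = "huggingface_hub is required for Hub operations."


def require_package(name: str, message: str) -> None:
    """Raise ImportError with a helpful message if a package is not importable."""
    if importlib.util.find_spec(name) is None:
        raise ImportError(message)


# A module that is always available passes silently.
require_package("json", "json is part of the standard library")

# A module that is not installed triggers the custom message.
try:
    require_package("surely_not_installed_pkg_123", HYPOTHETICAL_MISSING_MESSAGE)
    guard_result = "available"
except ImportError as exc:
    guard_result = str(exc)
```

Checking availability up front lets the rest of the module import cleanly for users who never touch the Hub.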
- class lczerolens.model.ForceValue(module, out_keys, **kwargs)
Bases:
LczeroModel
Class for forcing and isolating the value flow.
- Parameters:
module (torch.nn.Module)
out_keys (List[str])
- class lczerolens.model.Flow(module, out_keys, **kwargs)
Bases:
LczeroModel
Base class for isolating a flow.
- Parameters:
module (torch.nn.Module)
out_keys (List[str])
- class lczerolens.model.PolicyFlow(module, out_keys, **kwargs)
Bases:
Flow
Class for isolating the policy flow.
- Parameters:
module (torch.nn.Module)
out_keys (List[str])
- class lczerolens.model.ValueFlow(module, out_keys, **kwargs)
Bases:
Flow
Class for isolating the value flow.
- Parameters:
module (torch.nn.Module)
out_keys (List[str])