diff --git a/python/metatensor-torch/metatensor/torch/atomistic/model.py b/python/metatensor-torch/metatensor/torch/atomistic/model.py index 54cd9e861..e63c6742b 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/model.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/model.py @@ -4,7 +4,8 @@ import os import platform import warnings -from typing import Dict, List, Optional +from pathlib import Path +from typing import Dict, List, Optional, Union import torch from torch.profiler import record_function @@ -39,7 +40,8 @@ def load_atomistic_model(path, extensions_directory=None) -> "MetatensorAtomisti :param extensions_directory: path to a directory containing all extensions required by the exported model """ - load_model_extensions(path, extensions_directory) + path = str(path) + load_model_extensions(path, str(extensions_directory)) check_atomistic_model(path) return torch.jit.load(path) @@ -233,6 +235,14 @@ class MetatensorAtomisticModel(torch.nn.Module): >>> with tempfile.TemporaryDirectory() as directory: ... wrapped.save(os.path.join(directory, "constant-energy-model.pt")) ... + + .. py:attribute:: module + :type: ModelInterface + + The torch module wrapped by this :py:class:`MetatensorAtomisticModel`. + + Reading from this attribute is safe, but modifying it is not recommended, + unless you are familiar with the implementation of the model. """ # Some annotation to make the TorchScript compiler happy @@ -260,15 +270,15 @@ def __init__( raise ValueError("module should not be in training mode") _check_annotation(module) - self._module = module + self.module = module # ============================================================================ # # recursively explore `module` to get all the requested_neighbor_lists self._requested_neighbor_lists = [] _get_requested_neighbor_lists( - self._module, - self._module.__class__.__name__, + self.module, + self.module.__class__.__name__, self._requested_neighbor_lists, capabilities.length_unit, ) @@ -309,10 +319,6 @@ def __init__( else: raise ValueError(f"unknown dtype in capabilities: {capabilities.dtype}") - def wrapped_module(self) -> torch.nn.Module: - """Get the module wrapped in this :py:class:`MetatensorAtomisticModel`""" - return self._module - @torch.jit.export def capabilities(self) -> ModelCapabilities: """Get the capabilities of the wrapped model""" @@ -384,7 +390,7 @@ def forward( # run the actual calculations with record_function("Model::forward"): - outputs = self._module( + outputs = self.module( systems=systems, outputs=options.outputs, selected_atoms=options.selected_atoms, @@ -447,7 +453,7 @@ def export(self, file: str, collect_extensions: Optional[str] = None): ) return self.save(file, collect_extensions) - def save(self, file: str, collect_extensions: Optional[str] = None): + def save(self, file: Union[str, Path], collect_extensions: Optional[str] = None): """Save this model to a file that can then be loaded by simulation engine. :param file: where to save the model. This can be a path or a file-like object. @@ -484,7 +490,7 @@ def save(self, file: str, collect_extensions: Optional[str] = None): torch.jit.save( module.to("cpu"), # this allows to torch.jit.load without devices - file, + str(file), _extra_files={ "torch-version": torch.__version__, "metatensor-version": metatensor_version, diff --git a/python/metatensor-torch/tests/atomistic/ase_calculator.py b/python/metatensor-torch/tests/atomistic/ase_calculator.py index 4dceb42df..e35a553e9 100644 --- a/python/metatensor-torch/tests/atomistic/ase_calculator.py +++ b/python/metatensor-torch/tests/atomistic/ase_calculator.py @@ -237,7 +237,7 @@ def test_dtype_device(tmpdir, model, atoms): # re-create the model with a different dtype dtype_model = MetatensorAtomisticModel( - model._module.to(STR_TO_DTYPE[dtype]), + model.module.to(STR_TO_DTYPE[dtype]), model.metadata(), capabilities, ) diff --git a/python/metatensor-torch/tests/atomistic/model.py b/python/metatensor-torch/tests/atomistic/model.py index 83432bd09..8acb3b5b6 100644 --- a/python/metatensor-torch/tests/atomistic/model.py +++ b/python/metatensor-torch/tests/atomistic/model.py @@ -14,6 +14,7 @@ NeighborListOptions, System, check_atomistic_model, + load_atomistic_model, ) @@ -240,3 +241,29 @@ def test_bad_capabilities(): ) with pytest.raises(ValueError, match=message): ModelCapabilities(outputs={"not-a-standard::": ModelOutput()}) + + +def test_access_module(tmpdir): + model = FullModel() + model.train(False) + + capabilities = ModelCapabilities( + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + atomistic = MetatensorAtomisticModel(model, ModelMetadata(), capabilities) + + # Access wrapped module + assert atomistic.module is model + + atomistic.save(tmpdir / "export.pt") + loaded_atomistic = load_atomistic_model(tmpdir / "export.pt") + + # Access wrapped module after loading + loaded_atomistic.module + + # Verfify that it contains the original submodules + loaded_atomistic.module.first + loaded_atomistic.module.second + loaded_atomistic.module.other