diff --git a/docs/src/atomistic/outputs.rst b/docs/src/atomistic/outputs.rst index c4f8998ef..9bc63c147 100644 --- a/docs/src/atomistic/outputs.rst +++ b/docs/src/atomistic/outputs.rst @@ -15,6 +15,8 @@ encouraged to come together, define the metadata they need and add a new section to this page. +.. _energy: + Energy ^^^^^^ @@ -55,6 +57,8 @@ have the following metadata: - the energy must have a single property dimension named ``"energy"``, with a single entry set to ``0``. +.. _energy-gradients: + Energy gradients ---------------- @@ -119,3 +123,47 @@ The following gradients can be defined and requested with - ``["xyz_1", "xyz_2"]`` - Both ``"xyz_1"`` and ``"xyz_2"`` have values ``[0, 1, 2]``, and correspond to the two axes of the 3x3 strain matrix :math:`\epsilon`. + + +Energy ensemble +^^^^^^^^^^^^^^^ + +An ensemble of energies is associated with the ``"energy_ensemble"`` key in the +model outputs. Such ensembles are sometimes used to perform uncertainty +quantification, using multiple prediction to estimate an error on the mean +prediction. + +Energy ensembles must have the following metadata: + +.. list-table:: Metadata for energy ensemble output + :widths: 2 3 7 + :header-rows: 1 + + * - Metadata + - Names + - Description + + * - keys + - same as `Energy`_ + - same as `Energy`_ + + * - samples + - same as `Energy`_ + - same as `Energy`_ + + * - components + - same as `Energy`_ + - same as `Energy`_ + + * - properties + - ``"energy"`` + - the energy ensemble must have a single property dimension named + ``"energy"``, with entries ranging from 0 to the number of members of the + ensemble minus one. + + +Energy ensemble gradients +------------------------- + +The gradient metadata for energy ensemble is the same as for the ``energy`` +output (see `Energy gradients`_). diff --git a/metatensor-torch/src/atomistic/model.cpp b/metatensor-torch/src/atomistic/model.cpp index dea720f7f..59a009c68 100644 --- a/metatensor-torch/src/atomistic/model.cpp +++ b/metatensor-torch/src/atomistic/model.cpp @@ -136,7 +136,8 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) { /******************************************************************************/ std::unordered_set KNOWN_OUTPUTS = { - "energy" + "energy", + "energy_ensemble" }; void ModelCapabilitiesHolder::set_outputs(torch::Dict outputs) { diff --git a/python/metatensor-torch/metatensor/torch/atomistic/outputs.py b/python/metatensor-torch/metatensor/torch/atomistic/outputs.py index d06f1897f..aa90401f6 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/outputs.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/outputs.py @@ -44,26 +44,46 @@ def _check_outputs( ) if name == "energy": - _check_energy(systems, request, selected_atoms, energy=value) + _check_energy_like( + "energy", + value, + systems, + request, + selected_atoms, + ) + elif name == "energy_ensemble": + _check_energy_like( + "energy_ensemble", + value, + systems, + request, + selected_atoms, + ) else: # this is a non-standard output, there is nothing to check continue -def _check_energy( +def _check_energy_like( + name: str, + value: TensorMap, systems: List[System], request: ModelOutput, selected_atoms: Optional[Labels], - energy: TensorMap, ): - """Check the "energy" output metadata""" - if energy.keys != Labels("_", torch.tensor([[0]])): + """ + Check either "energy" or "energy_ensemble" output metadata + """ + + assert name in ["energy", "energy_ensemble"] + + if value.keys != Labels("_", torch.tensor([[0]])): raise ValueError( - "invalid keys for 'energy' output: expected `Labels('_', [[0]])`" + f"invalid keys for '{name}' output: expected `Labels('_', [[0]])`" ) - device = energy.device - energy_block = energy.block_by_id(0) + device = value.device + energy_block = value.block_by_id(0) if request.per_atom: expected_samples_names = ["system", "atom"] @@ -72,7 +92,7 @@ def _check_energy( if energy_block.samples.names != expected_samples_names: raise ValueError( - "invalid sample names for 'energy' output: " + f"invalid sample names for '{name}' output: " f"expected {expected_samples_names}, got {energy_block.samples.names}" ) @@ -91,7 +111,7 @@ def _check_energy( if len(expected_samples.union(energy_block.samples)) != len(expected_samples): raise ValueError( - "invalid samples entries for 'energy' output, they do not match the " + f"invalid samples entries for '{name}' output, they do not match the " f"`systems` and `selected_atoms`. Expected samples:\n{expected_samples}" ) @@ -107,48 +127,58 @@ def _check_energy( if len(expected_samples.union(energy_block.samples)) != len(expected_samples): raise ValueError( - "invalid samples entries for 'energy' output, they do not match the " + f"invalid samples entries for '{name}' output, they do not match the " f"`systems` and `selected_atoms`. Expected samples:\n{expected_samples}" ) if len(energy_block.components) != 0: raise ValueError( - "invalid components for 'energy' output: components should be empty" + f"invalid components for '{name}' output: components should be empty" ) - if energy_block.properties != Labels("energy", torch.tensor([[0]], device=device)): - raise ValueError( - "invalid properties for 'energy' output: expected `Labels('energy', [[0]])`" + # the only difference between energy & energy_ensemble is in the properties + if name == "energy": + expected_properties = Labels("energy", torch.tensor([[0]], device=device)) + message = "`Labels('energy', [[0]])`" + else: + assert name == "energy_ensemble" + n_ensemble_members = energy_block.values.shape[-1] + expected_properties = Labels( + "energy", torch.arange(n_ensemble_members, device=device).reshape(-1, 1) ) + message = "`Labels('energy', [[0], ..., [n]])`" + + if energy_block.properties != expected_properties: + raise ValueError(f"invalid properties for '{name}' output: expected {message}") for parameter, gradient in energy_block.gradients(): if parameter not in ["strain", "positions"]: - raise ValueError(f"invalid gradient for 'energy' output: {parameter}") + raise ValueError(f"invalid gradient for '{name}' output: {parameter}") xyz = torch.tensor([[0], [1], [2]], device=device) # strain gradient checks if parameter == "strain": if gradient.samples.names != ["sample"]: raise ValueError( - "invalid samples for 'energy' output 'strain' gradients: " + f"invalid samples for '{name}' output 'strain' gradients: " f"expected the names to be ['sample'], got {gradient.samples.names}" ) if len(gradient.components) != 2: raise ValueError( - "invalid components for 'energy' output 'strain' gradients: " + f"invalid components for '{name}' output 'strain' gradients: " "expected two components" ) if gradient.components[0] != Labels("xyz_1", xyz): raise ValueError( - "invalid components for 'energy' output 'strain' gradients: " + f"invalid components for '{name}' output 'strain' gradients: " "expected Labels('xyz_1', [[0], [1], [2]]) for the first component" ) if gradient.components[1] != Labels("xyz_2", xyz): raise ValueError( - "invalid components for 'energy' output 'strain' gradients: " + f"invalid components for '{name}' output 'strain' gradients: " "expected Labels('xyz_2', [[0], [1], [2]]) for the second component" ) @@ -156,19 +186,19 @@ def _check_energy( if parameter == "positions": if gradient.samples.names != ["sample", "system", "atom"]: raise ValueError( - "invalid samples for 'energy' output 'positions' gradients: " + f"invalid samples for '{name}' output 'positions' gradients: " "expected the names to be ['sample', 'system', 'atom'], " f"got {gradient.samples.names}" ) if len(gradient.components) != 1: raise ValueError( - "invalid components for 'energy' output 'positions' gradients: " + f"invalid components for '{name}' output 'positions' gradients: " "expected one component" ) if gradient.components[0] != Labels("xyz", xyz): raise ValueError( - "invalid components for 'energy' output 'positions' gradients: " + f"invalid components for '{name}' output 'positions' gradients: " "expected Labels('xyz', [[0], [1], [2]]) for the first component" ) diff --git a/python/metatensor-torch/tests/atomistic/outputs.py b/python/metatensor-torch/tests/atomistic/outputs.py new file mode 100644 index 000000000..489be1d33 --- /dev/null +++ b/python/metatensor-torch/tests/atomistic/outputs.py @@ -0,0 +1,74 @@ +from typing import Dict, List, Optional + +import torch + +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ( + MetatensorAtomisticModel, + ModelCapabilities, + ModelEvaluationOptions, + ModelMetadata, + ModelOutput, + System, +) + + +class EnergyEnsembleModel(torch.nn.Module): + """A metatensor atomistic model returning an energy ensemble""" + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + assert "energy_ensemble" in outputs + assert not outputs["energy_ensemble"].per_atom + assert selected_atoms is None + + return_dict: Dict[str, TensorMap] = {} + block = TensorBlock( + values=torch.tensor([[0.0, 1.0, 2.0]] * len(systems), dtype=torch.float64), + samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), + components=[], + properties=Labels("energy", torch.tensor([[0], [1], [2]])), + ) + return_dict["energy_ensemble"] = TensorMap( + Labels("_", torch.tensor([[0]])), [block] + ) + return return_dict + + +def test_energy_ensemble_model(): + model = EnergyEnsembleModel() + + capabilities = ModelCapabilities( + length_unit="angstrom", + atomic_types=[1, 2, 3], + interaction_range=4.3, + outputs={"energy_ensemble": ModelOutput(per_atom=False)}, + supported_devices=["cpu"], + dtype="float64", + ) + + atomistic = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) + + system = System( + types=torch.tensor([1, 2, 3]), + positions=torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]], dtype=torch.float64), + cell=torch.zeros([3, 3], dtype=torch.float64), + ) + + options = ModelEvaluationOptions( + outputs={"energy_ensemble": ModelOutput(per_atom=False)} + ) + + result = atomistic([system, system], options, check_consistency=True) + assert "energy_ensemble" in result + + ensemble = result["energy_ensemble"] + + assert ensemble.keys == Labels("_", torch.tensor([[0]])) + assert list(ensemble.block().values.shape) == [2, 3] + assert ensemble.block().samples.names == ["system"] + assert ensemble.block().properties.names == ["energy"]