Skip to content

Commit

Permalink
Error prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 28, 2024
1 parent 7688355 commit e68dede
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
11 changes: 10 additions & 1 deletion python/metatensor-learn/metatensor/learn/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
#
# Any change to this file MUST be also be made to `metatensor/torch/learn.py`.

from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap
from metatensor import (
Labels,
LabelsEntry,
TensorBlock,
TensorMap,
equal_metadata,
equal_metadata_block,
)


def torch_jit_is_scripting():
Expand All @@ -19,6 +26,8 @@ def torch_jit_is_scripting():
check_isinstance = isinstance

__all__ = [
"equal_metadata",
"equal_metadata_block",
"Labels",
"LabelsEntry",
"TensorBlock",
Expand Down
83 changes: 71 additions & 12 deletions python/metatensor-learn/metatensor/learn/nn/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,82 @@
Module containing the :py:class:`AbsoluteError` and :py:class:`SquaredError` classes.
"""

from .._backend import TensorMap
from .module_map import ModuleMap
import torch

from .._backend import TensorBlock, TensorMap, equal_metadata, equal_metadata_block

class AbsoluteError(ModuleMap):

def __init__(self):
pass
def absolute_error(A: TensorMap, B: TensorMap) -> TensorMap:
if not equal_metadata(A, B):
raise ValueError(
"The two maps must have the same metadata in `absolute_error`."
)

def __call__(self, A: TensorMap, B: TensorMap) -> TensorMap:
pass
keys = []
blocks = []
for key, block_A in A.items():
block_B = B.block(key)
keys.append(key)
blocks.append(absolute_error_block(block_A, block_B))

return TensorMap(keys, blocks)

class SquaredError(ModuleMap):

def __init__(self):
pass
def absolute_error_block(A: TensorBlock, B: TensorBlock) -> TensorBlock:
if not equal_metadata_block(A, B):
raise ValueError(
"The two blocks must have the same metadata in `absolute_error_block`."
)

def __call__(self, A: TensorMap, B: TensorMap) -> TensorMap:
pass
values = torch.abs(A.values - B.values)
block = TensorBlock(
values=values,
samples=A.samples,
components=A.components,
properties=A.properties,
)
for gradient_name, gradient_A in A.gradients():
gradient_B = B.gradient(gradient_name)
block.add_gradient(
gradient_name,
absolute_error_block(gradient_A, gradient_B),
)

return block


def squared_error(A: TensorMap, B: TensorMap) -> TensorMap:
if not equal_metadata(A, B):
raise ValueError("The two maps must have the same metadata in `squared_error`.")

keys = []
blocks = []
for key, block_A in A.items():
block_B = B.block(key)
keys.append(key)
blocks.append(squared_error_block(block_A, block_B))

return TensorMap(keys, blocks)


def squared_error_block(A: TensorBlock, B: TensorBlock) -> TensorBlock:
if not equal_metadata_block(A, B):
raise ValueError(
"The two blocks must have the same metadata in `squared_error_block`."
)

values = (A.values - B.values) ** 2
block = TensorBlock(
values=values,
samples=A.samples,
components=A.components,
properties=A.properties,
)
for gradient_name, gradient_A in A.gradients():
gradient_B = B.gradient(gradient_name)
block.add_gradient(
gradient_name,
squared_error_block(gradient_A, gradient_B),
)

return block
11 changes: 10 additions & 1 deletion python/metatensor-torch/metatensor/torch/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import torch

import metatensor.learn
from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
from metatensor.torch import (
Labels,
LabelsEntry,
TensorBlock,
TensorMap,
equal_metadata,
equal_metadata_block,
)


# ==================================================================================== #
Expand All @@ -25,6 +32,8 @@
module.__dict__["LabelsEntry"] = LabelsEntry
module.__dict__["TensorBlock"] = TensorBlock
module.__dict__["TensorMap"] = TensorMap
module.__dict__["equal_metadata"] = equal_metadata
module.__dict__["equal_metadata_block"] = equal_metadata_block
module.__dict__["torch_jit_is_scripting"] = torch.jit.is_scripting


Expand Down

0 comments on commit e68dede

Please sign in to comment.