Skip to content

Commit e68dede

Browse files
committed
Error prototype
1 parent 7688355 commit e68dede

File tree

3 files changed

+91
-14
lines changed

3 files changed

+91
-14
lines changed

python/metatensor-learn/metatensor/learn/_backend.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99
#
1010
# Any change to this file MUST be also be made to `metatensor/torch/learn.py`.
1111

12-
from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap
12+
from metatensor import (
13+
Labels,
14+
LabelsEntry,
15+
TensorBlock,
16+
TensorMap,
17+
equal_metadata,
18+
equal_metadata_block,
19+
)
1320

1421

1522
def torch_jit_is_scripting():
@@ -19,6 +26,8 @@ def torch_jit_is_scripting():
1926
check_isinstance = isinstance
2027

2128
__all__ = [
29+
"equal_metadata",
30+
"equal_metadata_block",
2231
"Labels",
2332
"LabelsEntry",
2433
"TensorBlock",

python/metatensor-learn/metatensor/learn/nn/errors.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,82 @@
22
Module containing the :py:class:`AbsoluteError` and :py:class:`SquaredError` classes.
33
"""
44

5-
from .._backend import TensorMap
6-
from .module_map import ModuleMap
5+
import torch
76

7+
from .._backend import TensorBlock, TensorMap, equal_metadata, equal_metadata_block
88

9-
class AbsoluteError(ModuleMap):
109

11-
def __init__(self):
12-
pass
10+
def absolute_error(A: TensorMap, B: TensorMap) -> TensorMap:
11+
if not equal_metadata(A, B):
12+
raise ValueError(
13+
"The two maps must have the same metadata in `absolute_error`."
14+
)
1315

14-
def __call__(self, A: TensorMap, B: TensorMap) -> TensorMap:
15-
pass
16+
keys = []
17+
blocks = []
18+
for key, block_A in A.items():
19+
block_B = B.block(key)
20+
keys.append(key)
21+
blocks.append(absolute_error_block(block_A, block_B))
1622

23+
return TensorMap(keys, blocks)
1724

18-
class SquaredError(ModuleMap):
1925

20-
def __init__(self):
21-
pass
26+
def absolute_error_block(A: TensorBlock, B: TensorBlock) -> TensorBlock:
27+
if not equal_metadata_block(A, B):
28+
raise ValueError(
29+
"The two blocks must have the same metadata in `absolute_error_block`."
30+
)
2231

23-
def __call__(self, A: TensorMap, B: TensorMap) -> TensorMap:
24-
pass
32+
values = torch.abs(A.values - B.values)
33+
block = TensorBlock(
34+
values=values,
35+
samples=A.samples,
36+
components=A.components,
37+
properties=A.properties,
38+
)
39+
for gradient_name, gradient_A in A.gradients():
40+
gradient_B = B.gradient(gradient_name)
41+
block.add_gradient(
42+
gradient_name,
43+
absolute_error_block(gradient_A, gradient_B),
44+
)
45+
46+
return block
47+
48+
49+
def squared_error(A: TensorMap, B: TensorMap) -> TensorMap:
50+
if not equal_metadata(A, B):
51+
raise ValueError("The two maps must have the same metadata in `squared_error`.")
52+
53+
keys = []
54+
blocks = []
55+
for key, block_A in A.items():
56+
block_B = B.block(key)
57+
keys.append(key)
58+
blocks.append(squared_error_block(block_A, block_B))
59+
60+
return TensorMap(keys, blocks)
61+
62+
63+
def squared_error_block(A: TensorBlock, B: TensorBlock) -> TensorBlock:
64+
if not equal_metadata_block(A, B):
65+
raise ValueError(
66+
"The two blocks must have the same metadata in `squared_error_block`."
67+
)
68+
69+
values = (A.values - B.values) ** 2
70+
block = TensorBlock(
71+
values=values,
72+
samples=A.samples,
73+
components=A.components,
74+
properties=A.properties,
75+
)
76+
for gradient_name, gradient_A in A.gradients():
77+
gradient_B = B.gradient(gradient_name)
78+
block.add_gradient(
79+
gradient_name,
80+
squared_error_block(gradient_A, gradient_B),
81+
)
82+
83+
return block

python/metatensor-torch/metatensor/torch/learn.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
import torch
55

66
import metatensor.learn
7-
from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap
7+
from metatensor.torch import (
8+
Labels,
9+
LabelsEntry,
10+
TensorBlock,
11+
TensorMap,
12+
equal_metadata,
13+
equal_metadata_block,
14+
)
815

916

1017
# ==================================================================================== #
@@ -25,6 +32,8 @@
2532
module.__dict__["LabelsEntry"] = LabelsEntry
2633
module.__dict__["TensorBlock"] = TensorBlock
2734
module.__dict__["TensorMap"] = TensorMap
35+
module.__dict__["equal_metadata"] = equal_metadata
36+
module.__dict__["equal_metadata_block"] = equal_metadata_block
2837
module.__dict__["torch_jit_is_scripting"] = torch.jit.is_scripting
2938

3039

0 commit comments

Comments
 (0)