Skip to content

Commit ef1383c

Browse files
agoscinskiLuthaf
authored andcommitted
Add testing utilities to perform finite difference for operations
1 parent 0d9ad62 commit ef1383c

File tree

3 files changed

+197
-0
lines changed

3 files changed

+197
-0
lines changed

python/metatensor-operations/metatensor/operations/_dispatch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ def _check_all_np_ndarray(arrays):
5858
)
5959

6060

61+
def sum(array, axis: Optional[int] = None):
62+
"""
63+
Returns the sum of the elements in the array at the axis.
64+
65+
It is equivalent of np.sum(array, axis=axis) and torch.sum(tensor, dim=axis)
66+
"""
67+
if isinstance(array, TorchTensor):
68+
return torch.sum(array, dim=axis)
69+
elif isinstance(array, np.ndarray):
70+
return np.sum(array, axis=axis).astype(array.dtype)
71+
else:
72+
raise TypeError(UNKNOWN_ARRAY_TYPE)
73+
74+
6175
def abs(array):
6276
"""
6377
Returns the absolute value of the elements in the array.
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import numpy as np
2+
3+
import metatensor
4+
from metatensor import Labels, TensorBlock, TensorMap
5+
from metatensor.operations import _dispatch
6+
7+
8+
def check_finite_differences(
9+
function,
10+
array,
11+
*,
12+
parameter: str,
13+
displacement: float = 1e-6,
14+
rtol: float = 1e-5,
15+
atol: float = 1e-15,
16+
) -> None:
17+
"""
18+
Check that analytical gradients with respect to ``parameter`` in the
19+
:py:class:`TensorMap` returned by ``function`` agree with finite differences.
20+
21+
The ``function`` must take an array (either torch or numpy) and return a
22+
:py:class:`TensorMap`. All the blocks in the returned TensorMap should have one
23+
sample per row of the ``array``, and the gradient-specific components must match the
24+
other dimensions of the ``array``.
25+
"""
26+
n_samples = array.shape[0]
27+
n_grad_components = array.shape[1:]
28+
29+
reference = function(array)
30+
31+
values_components = reference.block(0).components
32+
grad_components = reference.block(0).gradient(parameter).components
33+
34+
assert len(grad_components) == len(values_components) + len(n_grad_components)
35+
36+
for sample_i in range(n_samples):
37+
for grad_components_i in np.ndindex(n_grad_components):
38+
array_pos = _dispatch.copy(array)
39+
index = (sample_i,) + grad_components_i
40+
array_pos[index] += displacement / 2
41+
updated_pos = function(array_pos)
42+
43+
array_neg = _dispatch.copy(array)
44+
array_neg[index] -= displacement / 2
45+
updated_neg = function(array_neg)
46+
47+
assert updated_pos.keys == reference.keys
48+
assert updated_neg.keys == reference.keys
49+
50+
for key, block in reference.items():
51+
gradients = block.gradient(parameter)
52+
53+
block_pos = updated_pos.block(key)
54+
block_neg = updated_neg.block(key)
55+
56+
for gradient_i, gradient_sample in enumerate(gradients.samples):
57+
current_sample_i = gradient_sample[0]
58+
if current_sample_i != sample_i:
59+
continue
60+
61+
assert block_pos.samples[sample_i] == block.samples[sample_i]
62+
assert block_neg.samples[sample_i] == block.samples[sample_i]
63+
64+
value_pos = block_pos.values[sample_i]
65+
value_neg = block_neg.values[sample_i]
66+
67+
grad_index = (gradient_i,) + grad_components_i
68+
gradient = gradients.values[grad_index]
69+
70+
assert value_pos.shape == gradient.shape
71+
assert value_neg.shape == gradient.shape
72+
73+
finite_difference = (value_pos - value_neg) / displacement
74+
75+
np.testing.assert_allclose(
76+
finite_difference,
77+
gradient,
78+
rtol=rtol,
79+
atol=atol,
80+
)
81+
82+
83+
def cartesian_cubic(array) -> TensorMap:
84+
"""
85+
Creates a TensorMap from a set of cartesian vectors according to the function:
86+
87+
.. math::
88+
89+
f(x, y, z) = x^3 + y^3 + z^3
90+
91+
\\nabla f = (3x^2, 3y^2, 3z^2)
92+
93+
"""
94+
n_samples = array.shape[0]
95+
assert array.shape == (n_samples, 3)
96+
97+
values = _dispatch.sum(array**3, axis=1)
98+
values_grad = 3 * array**2
99+
100+
block = metatensor.block_from_array(values.reshape(n_samples, 1))
101+
block.add_gradient(
102+
parameter="g",
103+
gradient=TensorBlock(
104+
values=values_grad.reshape(n_samples, 3, 1),
105+
samples=Labels.range("sample", len(values)),
106+
components=[Labels.range("xyz", 3)],
107+
properties=block.properties,
108+
),
109+
)
110+
111+
return TensorMap(Labels.range("_", 1), [block])
112+
113+
114+
def cartesian_linear(array) -> TensorMap:
115+
"""
116+
Creates a TensorMap from a set of cartesian vectors according to the function:
117+
118+
.. math::
119+
120+
f(x, y, z) = 3x + 2y + 8*z + 4
121+
122+
\\nabla f = (3, 2, 8)
123+
124+
"""
125+
n_samples = array.shape[0]
126+
assert array.shape == (n_samples, 3)
127+
128+
values = 3 * array[:, 0] + 2 * array[:, 1] + 8 * array[:, 2] + 4
129+
130+
values_grad = _dispatch.zeros_like(array, (n_samples, 3, 1))
131+
values_grad[:, 0] = 3 * _dispatch.ones_like(array, (n_samples, 1))
132+
values_grad[:, 1] = 2 * _dispatch.ones_like(array, (n_samples, 1))
133+
values_grad[:, 2] = 8 * _dispatch.ones_like(array, (n_samples, 1))
134+
135+
block = metatensor.block_from_array(values.reshape(-1, 1))
136+
block.add_gradient(
137+
parameter="g",
138+
gradient=TensorBlock(
139+
values=values_grad,
140+
samples=Labels.range("sample", len(values)),
141+
components=[Labels.range("xyz", 3)],
142+
properties=block.properties,
143+
),
144+
)
145+
146+
return TensorMap(Labels.range("_", 1), [block])
147+
148+
149+
def test_basic_functions():
150+
array = np.random.rand(42, 3)
151+
check_finite_differences(cartesian_cubic, array, parameter="g")
152+
check_finite_differences(cartesian_linear, array, parameter="g")

python/metatensor-operations/tests/add.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@
44
import metatensor
55
from metatensor import Labels, TensorBlock, TensorMap
66

7+
from . import _gradcheck
8+
9+
10+
try:
11+
import torch
12+
13+
HAS_TORCH = True
14+
except ImportError:
15+
HAS_TORCH = False
16+
717

818
@pytest.fixture
919
def keys():
@@ -258,3 +268,24 @@ def test_self_add_error():
258268
)
259269
with pytest.raises(TypeError, match=message):
260270
metatensor.add(tensor, np.ones((3, 4)))
271+
272+
273+
def test_add_finite_difference():
274+
def function(array):
275+
tensor_1 = _gradcheck.cartesian_linear(array)
276+
tensor_2 = _gradcheck.cartesian_cubic(array)
277+
return metatensor.add(tensor_1, tensor_2)
278+
279+
array = np.random.rand(5, 3)
280+
_gradcheck.check_finite_differences(function, array, parameter="g")
281+
282+
283+
@pytest.mark.skipif(not HAS_TORCH, reason="requires torch")
284+
def test_torch_add_finite_difference():
285+
def function(array):
286+
tensor_1 = _gradcheck.cartesian_linear(array)
287+
tensor_2 = _gradcheck.cartesian_cubic(array)
288+
return metatensor.add(tensor_1, tensor_2)
289+
290+
array = torch.rand(5, 3, dtype=torch.float64)
291+
_gradcheck.check_finite_differences(function, array, parameter="g")

0 commit comments

Comments
 (0)