Skip to content

Commit 884ab63

Browse files
Add testing utilities to perform finite difference for operations
- Add finite difference test for add operation - Add dispatch functions to add tests for torch array backend Co-authored-by: Divya Suman <[email protected]>
1 parent 111f0f9 commit 884ab63

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

python/metatensor-operations/metatensor/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,5 @@
6060
from .subtract import subtract # noqa
6161
from .unique_metadata import unique_metadata, unique_metadata_block # noqa
6262
from .zeros_like import zeros_like, zeros_like_block # noqa
63+
64+
from . import _testing # noqa

python/metatensor-operations/metatensor/operations/_dispatch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ def _check_all_np_ndarray(arrays):
4646
)
4747

4848

49+
def sum(array, axis: Optional[int] = None):
50+
"""
51+
Returns the sum of the elements in the array at the axis.
52+
53+
It is equivalent of np.sum(array, axis=axis) and torch.sum(tensor, dim=axis)
54+
"""
55+
if isinstance(array, TorchTensor):
56+
return torch.sum(array, dim=axis)
57+
elif isinstance(array, np.ndarray):
58+
return np.sum(array, axis=axis).astype(array.dtype)
59+
else:
60+
raise TypeError(UNKNOWN_ARRAY_TYPE)
61+
62+
4963
def abs(array):
5064
"""
5165
Returns the absolute value of the elements in the array.

python/metatensor-operations/tests/add.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,30 @@
33

44
import metatensor
55
from metatensor import Labels, TensorBlock, TensorMap
6+
from metatensor.operations._testing import (
7+
cartesian_cubic,
8+
cartesian_linear,
9+
finite_differences,
10+
)
11+
12+
13+
try:
14+
import torch # noqa
15+
16+
HAS_TORCH = True
17+
except ImportError:
18+
HAS_TORCH = False
19+
20+
21+
@pytest.fixture(scope="module", autouse=True)
22+
def set_random_generator():
23+
"""Set the random generator to same seed before each test is run.
24+
Otherwise test behaviour is dependend on the order of the tests
25+
in this file and the number of parameters of the test.
26+
"""
27+
np.random.seed(1225787)
28+
if HAS_TORCH:
29+
torch.manual_seed(1225787)
630

731

832
@pytest.fixture
@@ -258,3 +282,24 @@ def test_self_add_error():
258282
)
259283
with pytest.raises(TypeError, match=message):
260284
metatensor.add(tensor, np.ones((3, 4)))
285+
286+
287+
def test_add_finite_difference():
288+
def add_callable(cartesian_vectors, compute_grad=False):
289+
tensor1 = cartesian_linear(cartesian_vectors, compute_grad)
290+
tensor2 = cartesian_cubic(cartesian_vectors, compute_grad)
291+
return metatensor.add(tensor1, tensor2)
292+
293+
input_array = np.random.rand(5, 3)
294+
finite_differences(add_callable, input_array, "positions")
295+
296+
297+
@pytest.mark.skipif(not HAS_TORCH, reason="requires torch")
298+
def test_torch_add_finite_difference():
299+
def add_callable(cartesian_vectors, compute_grad=False):
300+
tensor1 = cartesian_linear(cartesian_vectors, compute_grad)
301+
tensor2 = cartesian_cubic(cartesian_vectors, compute_grad)
302+
return metatensor.add(tensor1, tensor2)
303+
304+
input_array = torch.rand(5, 3, dtype=torch.float64)
305+
finite_differences(add_callable, input_array, "positions")

0 commit comments

Comments
 (0)