Skip to content

Commit e2d5c01

Browse files
committed
Finish up the code and clean up tests
1 parent de733be commit e2d5c01

File tree

10 files changed

+68
-75
lines changed

10 files changed

+68
-75
lines changed

docs/src/operations/reference/logic/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ Logic function
66

77
allclose() <allclose>
88
equal() <equal>
9-
equal_metadata() <equal_metadata>
10-
is_contiguous() <is_contiguous>
9+
equal_metadata() <equal-metadata>
10+
is_contiguous() <is-contiguous>

python/metatensor-operations/metatensor/operations/_dispatch.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def concatenate(arrays: List[TorchTensor], axis: int):
343343
raise TypeError(UNKNOWN_ARRAY_TYPE)
344344

345345

346-
def is_contiguous_array(array):
346+
def is_contiguous(array):
347347
"""
348348
Checks if a given array is contiguous.
349349
@@ -360,11 +360,12 @@ def is_contiguous_array(array):
360360
raise TypeError(UNKNOWN_ARRAY_TYPE)
361361

362362

363-
def make_contiguous_array(array):
363+
def make_contiguous(array):
364364
"""
365365
Returns a contiguous array.
366-
It is equivalent of np.ascontiguousarray(array) and tensor.contiguous(). In
367-
the case of numpy, C order is used for consistency with torch. As such, only
366+
367+
This is equivalent of ``np.ascontiguousarray(array)`` and ``tensor.contiguous()``.
368+
In the case of numpy, C order is used for consistency with torch. As such, only
368369
C-contiguity is checked.
369370
"""
370371
if isinstance(array, TorchTensor):
Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
from . import _dispatch
2-
from ._backend import (
3-
Labels,
4-
TensorBlock,
5-
TensorMap,
6-
torch_jit_is_scripting,
7-
torch_jit_script,
8-
)
2+
from ._backend import TensorBlock, TensorMap, torch_jit_script
93

104

115
@torch_jit_script
@@ -20,15 +14,17 @@ def is_contiguous_block(block: TensorBlock) -> bool:
2014
2115
:return: bool, true if all values arrays contiguous, false otherwise.
2216
"""
23-
check_contiguous: bool = True
24-
if not _dispatch.is_contiguous_array(block.values):
25-
check_contiguous = False
17+
if not _dispatch.is_contiguous(block.values):
18+
return False
2619

27-
for _param, gradient in block.gradients():
28-
if not _dispatch.is_contiguous_array(gradient.values):
29-
check_contiguous = False
20+
for _parameter, gradient in block.gradients():
21+
if len(gradient.gradients_list()) != 0:
22+
raise NotImplementedError("gradients of gradients are not supported")
3023

31-
return check_contiguous
24+
if not _dispatch.is_contiguous(gradient.values):
25+
return False
26+
27+
return True
3228

3329

3430
@torch_jit_script
@@ -43,9 +39,8 @@ def is_contiguous(tensor: TensorMap) -> bool:
4339
4440
:return: bool, true if all values arrays contiguous, false otherwise.
4541
"""
46-
check_contiguous: bool = True
47-
for _key, block in tensor.items():
42+
for block in tensor.blocks():
4843
if not is_contiguous_block(block):
49-
check_contiguous = False
44+
return False
5045

51-
return check_contiguous
46+
return True
Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,49 @@
1+
from typing import List
2+
13
from . import _dispatch
2-
from ._backend import (
3-
Labels,
4-
TensorBlock,
5-
TensorMap,
6-
torch_jit_is_scripting,
7-
torch_jit_script,
8-
)
4+
from ._backend import TensorBlock, TensorMap, torch_jit_script
95

106

117
@torch_jit_script
128
def make_contiguous_block(block: TensorBlock) -> TensorBlock:
139
"""
14-
Returns a new :py:class:`TensorBlock` where the values and gradient values (if
15-
present) arrays are mades to be contiguous.
10+
Returns a new :py:class:`TensorBlock` where the values and gradient (if present)
11+
arrays are made to be contiguous.
1612
1713
:param block: the input :py:class:`TensorBlock`.
18-
19-
:return: a new :py:class:`TensorBlock` where the values and gradients arrays (if
20-
present) are contiguous.
2114
"""
22-
contiguous_block = TensorBlock(
23-
values=_dispatch.make_contiguous_array(block.values.copy()),
15+
new_block = TensorBlock(
16+
values=_dispatch.make_contiguous(block.values),
2417
samples=block.samples,
2518
components=block.components,
2619
properties=block.properties,
2720
)
28-
for param, gradient in block.gradients():
21+
22+
for parameter, gradient in block.gradients():
23+
if len(gradient.gradients_list()) != 0:
24+
raise NotImplementedError("gradients of gradients are not supported")
25+
2926
new_gradient = TensorBlock(
30-
values=_dispatch.make_contiguous_array(gradient.values.copy()),
27+
values=_dispatch.make_contiguous(gradient.values),
3128
samples=gradient.samples,
3229
components=gradient.components,
3330
properties=gradient.properties,
3431
)
35-
contiguous_block.add_gradient(param, new_gradient)
32+
new_block.add_gradient(parameter, new_gradient)
3633

37-
return contiguous_block
34+
return new_block
3835

3936

4037
@torch_jit_script
4138
def make_contiguous(tensor: TensorMap) -> TensorMap:
4239
"""
4340
Returns a new :py:class:`TensorMap` where all values and gradient values arrays are
44-
mades to be contiguous.
41+
made to be contiguous.
4542
4643
:param tensor: the input :py:class:`TensorMap`.
47-
48-
:return: a new :py:class:`TensorMap` with the same data and metadata as ``tensor``
49-
and contiguous values of ``tensor``.
5044
"""
51-
keys: Labels = tensor.keys
52-
contiguous_blocks: List[TensorBlock] = []
53-
for _key, block in tensor.items():
54-
contiguous_block = make_contiguous_block(block)
55-
contiguous_blocks.append(contiguous_block)
45+
new_blocks: List[TensorBlock] = []
46+
for block in tensor.blocks():
47+
new_blocks.append(make_contiguous_block(block))
5648

57-
return TensorMap(keys=keys, blocks=contiguous_blocks)
49+
return TensorMap(tensor.keys, new_blocks)

python/metatensor-operations/tests/is_contiguous.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,22 @@ def tensor():
1919

2020

2121
@pytest.fixture
22-
def incontiguous_tensor(tensor) -> TensorMap:
22+
def non_contiguous_tensor(tensor) -> TensorMap:
2323
"""
2424
Make a TensorMap non-contiguous by reversing the order of the samples/properties
2525
rows/columns in all values and gradient blocks.
2626
"""
2727
keys = tensor.keys
2828
new_blocks = []
2929

30-
for _key, block in tensor.items():
31-
# Create a new TM with a non-contig array
32-
new_block = _incontiguous_block(block)
30+
for block in tensor.blocks():
31+
new_block = _non_contiguous_block(block)
3332
new_blocks.append(new_block)
3433

3534
return TensorMap(keys=keys, blocks=new_blocks)
3635

3736

38-
def _incontiguous_block(block: TensorBlock) -> TensorBlock:
37+
def _non_contiguous_block(block: TensorBlock) -> TensorBlock:
3938
"""
4039
Make a non-contiguous block by reversing the order in both the main value block and
4140
the gradient block(s).
@@ -62,9 +61,9 @@ def _incontiguous_block(block: TensorBlock) -> TensorBlock:
6261

6362
def test_is_contiguous_block(tensor):
6463
assert metatensor.is_contiguous_block(tensor.block(0))
65-
assert not metatensor.is_contiguous_block(_incontiguous_block(tensor.block(0)))
64+
assert not metatensor.is_contiguous_block(_non_contiguous_block(tensor.block(0)))
6665

6766

68-
def test_is_contiguous(tensor, incontiguous_tensor):
67+
def test_is_contiguous(tensor, non_contiguous_tensor):
6968
assert metatensor.is_contiguous(tensor)
70-
assert not metatensor.is_contiguous(incontiguous_tensor)
69+
assert not metatensor.is_contiguous(non_contiguous_tensor)

python/metatensor-operations/tests/make_contiguous.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def tensor():
1919

2020

2121
@pytest.fixture
22-
def incontiguous_tensor(tensor) -> TensorMap:
22+
def non_contiguous_tensor(tensor) -> TensorMap:
2323
"""
2424
Make a TensorMap non-contiguous by reversing the order of the samples/properties
2525
rows/columns in all values and gradient blocks.
@@ -28,14 +28,13 @@ def incontiguous_tensor(tensor) -> TensorMap:
2828
new_blocks = []
2929

3030
for _key, block in tensor.items():
31-
# Create a new TM with a non-contig array
32-
new_block = _incontiguous_block(block)
31+
new_block = _non_contiguous_block(block)
3332
new_blocks.append(new_block)
3433

3534
return TensorMap(keys=keys, blocks=new_blocks)
3635

3736

38-
def _incontiguous_block(block: TensorBlock) -> TensorBlock:
37+
def _non_contiguous_block(block: TensorBlock) -> TensorBlock:
3938
"""
4039
Make a non-contiguous block by reversing the order in both the main value block and
4140
the gradient block(s).
@@ -60,13 +59,13 @@ def _incontiguous_block(block: TensorBlock) -> TensorBlock:
6059
return new_block
6160

6261

63-
def test_is_contiguous_block(tensor):
64-
assert not metatensor.is_contiguous_block(_incontiguous_block(tensor.block(0)))
65-
assert metatensor.is_contiguous_block(
66-
metatensor.make_contiguous_block(_incontiguous_block(tensor.block(0)))
67-
)
62+
def test_make_contiguous_block(tensor):
63+
block = _non_contiguous_block(tensor.block(0))
64+
65+
assert not metatensor.is_contiguous_block(block)
66+
assert metatensor.is_contiguous_block(metatensor.make_contiguous_block(block))
6867

6968

70-
def test_is_contiguous(incontiguous_tensor):
71-
assert not metatensor.is_contiguous(incontiguous_tensor)
72-
assert metatensor.is_contiguous(metatensor.make_contiguous(incontiguous_tensor))
69+
def test_make_contiguous(non_contiguous_tensor):
70+
assert not metatensor.is_contiguous(non_contiguous_tensor)
71+
assert metatensor.is_contiguous(metatensor.make_contiguous(non_contiguous_tensor))

python/metatensor-torch/tests/operations/contiguous.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
import metatensor.torch
66

77

8-
def test_is_contiguous():
9-
# TODO: write tests, used as a placeholder for now
10-
assert True
8+
def test_save_load():
9+
with io.BytesIO() as buffer:
10+
torch.jit.save(metatensor.torch.is_contiguous, buffer)
11+
buffer.seek(0)
12+
torch.jit.load(buffer)
13+
14+
with io.BytesIO() as buffer:
15+
torch.jit.save(metatensor.torch.make_contiguous, buffer)
16+
buffer.seek(0)
17+
torch.jit.load(buffer)

0 commit comments

Comments
 (0)