Skip to content

Commit

Permalink
Finish up the code and clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Jul 3, 2024
1 parent de733be commit e2d5c01
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 75 deletions.
4 changes: 2 additions & 2 deletions docs/src/operations/reference/logic/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ Logic function

allclose() <allclose>
equal() <equal>
equal_metadata() <equal_metadata>
is_contiguous() <is_contiguous>
equal_metadata() <equal-metadata>
is_contiguous() <is-contiguous>
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def concatenate(arrays: List[TorchTensor], axis: int):
raise TypeError(UNKNOWN_ARRAY_TYPE)


def is_contiguous_array(array):
def is_contiguous(array):
"""
Checks if a given array is contiguous.
Expand All @@ -360,11 +360,12 @@ def is_contiguous_array(array):
raise TypeError(UNKNOWN_ARRAY_TYPE)


def make_contiguous_array(array):
def make_contiguous(array):
"""
Returns a contiguous array.
It is equivalent of np.ascontiguousarray(array) and tensor.contiguous(). In
the case of numpy, C order is used for consistency with torch. As such, only
This is the equivalent of ``np.ascontiguousarray(array)`` and ``tensor.contiguous()``.
In the case of numpy, C order is used for consistency with torch. As such, only
C-contiguity is checked.
"""
if isinstance(array, TorchTensor):
Expand Down
31 changes: 13 additions & 18 deletions python/metatensor-operations/metatensor/operations/is_contiguous.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from . import _dispatch
from ._backend import (
Labels,
TensorBlock,
TensorMap,
torch_jit_is_scripting,
torch_jit_script,
)
from ._backend import TensorBlock, TensorMap, torch_jit_script


@torch_jit_script
Expand All @@ -20,15 +14,17 @@ def is_contiguous_block(block: TensorBlock) -> bool:
:return: bool, true if all values arrays contiguous, false otherwise.
"""
check_contiguous: bool = True
if not _dispatch.is_contiguous_array(block.values):
check_contiguous = False
if not _dispatch.is_contiguous(block.values):
return False

for _param, gradient in block.gradients():
if not _dispatch.is_contiguous_array(gradient.values):
check_contiguous = False
for _parameter, gradient in block.gradients():
if len(gradient.gradients_list()) != 0:
raise NotImplementedError("gradients of gradients are not supported")

return check_contiguous
if not _dispatch.is_contiguous(gradient.values):
return False

return True


@torch_jit_script
Expand All @@ -43,9 +39,8 @@ def is_contiguous(tensor: TensorMap) -> bool:
:return: bool, true if all values arrays contiguous, false otherwise.
"""
check_contiguous: bool = True
for _key, block in tensor.items():
for block in tensor.blocks():
if not is_contiguous_block(block):
check_contiguous = False
return False

return check_contiguous
return True
Original file line number Diff line number Diff line change
@@ -1,57 +1,49 @@
from typing import List

from . import _dispatch
from ._backend import (
Labels,
TensorBlock,
TensorMap,
torch_jit_is_scripting,
torch_jit_script,
)
from ._backend import TensorBlock, TensorMap, torch_jit_script


@torch_jit_script
def make_contiguous_block(block: TensorBlock) -> TensorBlock:
"""
Returns a new :py:class:`TensorBlock` where the values and gradient values (if
present) arrays are mades to be contiguous.
Returns a new :py:class:`TensorBlock` where the values and gradient (if present)
arrays are made to be contiguous.
:param block: the input :py:class:`TensorBlock`.
:return: a new :py:class:`TensorBlock` where the values and gradients arrays (if
present) are contiguous.
"""
contiguous_block = TensorBlock(
values=_dispatch.make_contiguous_array(block.values.copy()),
new_block = TensorBlock(
values=_dispatch.make_contiguous(block.values),
samples=block.samples,
components=block.components,
properties=block.properties,
)
for param, gradient in block.gradients():

for parameter, gradient in block.gradients():
if len(gradient.gradients_list()) != 0:
raise NotImplementedError("gradients of gradients are not supported")

new_gradient = TensorBlock(
values=_dispatch.make_contiguous_array(gradient.values.copy()),
values=_dispatch.make_contiguous(gradient.values),
samples=gradient.samples,
components=gradient.components,
properties=gradient.properties,
)
contiguous_block.add_gradient(param, new_gradient)
new_block.add_gradient(parameter, new_gradient)

return contiguous_block
return new_block


@torch_jit_script
def make_contiguous(tensor: TensorMap) -> TensorMap:
"""
Returns a new :py:class:`TensorMap` where all values and gradient values arrays are
mades to be contiguous.
made to be contiguous.
:param tensor: the input :py:class:`TensorMap`.
:return: a new :py:class:`TensorMap` with the same data and metadata as ``tensor``
and contiguous values of ``tensor``.
"""
keys: Labels = tensor.keys
contiguous_blocks: List[TensorBlock] = []
for _key, block in tensor.items():
contiguous_block = make_contiguous_block(block)
contiguous_blocks.append(contiguous_block)
new_blocks: List[TensorBlock] = []
for block in tensor.blocks():
new_blocks.append(make_contiguous_block(block))

return TensorMap(keys=keys, blocks=contiguous_blocks)
return TensorMap(tensor.keys, new_blocks)
15 changes: 7 additions & 8 deletions python/metatensor-operations/tests/is_contiguous.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,22 @@ def tensor():


@pytest.fixture
def incontiguous_tensor(tensor) -> TensorMap:
def non_contiguous_tensor(tensor) -> TensorMap:
"""
Make a TensorMap non-contiguous by reversing the order of the samples/properties
rows/columns in all values and gradient blocks.
"""
keys = tensor.keys
new_blocks = []

for _key, block in tensor.items():
# Create a new TM with a non-contig array
new_block = _incontiguous_block(block)
for block in tensor.blocks():
new_block = _non_contiguous_block(block)
new_blocks.append(new_block)

return TensorMap(keys=keys, blocks=new_blocks)


def _incontiguous_block(block: TensorBlock) -> TensorBlock:
def _non_contiguous_block(block: TensorBlock) -> TensorBlock:
"""
Make a non-contiguous block by reversing the order in both the main value block and
the gradient block(s).
Expand All @@ -62,9 +61,9 @@ def _incontiguous_block(block: TensorBlock) -> TensorBlock:

def test_is_contiguous_block(tensor):
assert metatensor.is_contiguous_block(tensor.block(0))
assert not metatensor.is_contiguous_block(_incontiguous_block(tensor.block(0)))
assert not metatensor.is_contiguous_block(_non_contiguous_block(tensor.block(0)))


def test_is_contiguous(tensor, incontiguous_tensor):
def test_is_contiguous(tensor, non_contiguous_tensor):
assert metatensor.is_contiguous(tensor)
assert not metatensor.is_contiguous(incontiguous_tensor)
assert not metatensor.is_contiguous(non_contiguous_tensor)
23 changes: 11 additions & 12 deletions python/metatensor-operations/tests/make_contiguous.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def tensor():


@pytest.fixture
def incontiguous_tensor(tensor) -> TensorMap:
def non_contiguous_tensor(tensor) -> TensorMap:
"""
Make a TensorMap non-contiguous by reversing the order of the samples/properties
rows/columns in all values and gradient blocks.
Expand All @@ -28,14 +28,13 @@ def incontiguous_tensor(tensor) -> TensorMap:
new_blocks = []

for _key, block in tensor.items():
# Create a new TM with a non-contig array
new_block = _incontiguous_block(block)
new_block = _non_contiguous_block(block)
new_blocks.append(new_block)

return TensorMap(keys=keys, blocks=new_blocks)


def _incontiguous_block(block: TensorBlock) -> TensorBlock:
def _non_contiguous_block(block: TensorBlock) -> TensorBlock:
"""
Make a non-contiguous block by reversing the order in both the main value block and
the gradient block(s).
Expand All @@ -60,13 +59,13 @@ def _incontiguous_block(block: TensorBlock) -> TensorBlock:
return new_block


def test_is_contiguous_block(tensor):
assert not metatensor.is_contiguous_block(_incontiguous_block(tensor.block(0)))
assert metatensor.is_contiguous_block(
metatensor.make_contiguous_block(_incontiguous_block(tensor.block(0)))
)
def test_make_contiguous_block(tensor):
block = _non_contiguous_block(tensor.block(0))

assert not metatensor.is_contiguous_block(block)
assert metatensor.is_contiguous_block(metatensor.make_contiguous_block(block))


def test_is_contiguous(incontiguous_tensor):
assert not metatensor.is_contiguous(incontiguous_tensor)
assert metatensor.is_contiguous(metatensor.make_contiguous(incontiguous_tensor))
def test_make_contiguous(non_contiguous_tensor):
assert not metatensor.is_contiguous(non_contiguous_tensor)
assert metatensor.is_contiguous(metatensor.make_contiguous(non_contiguous_tensor))
13 changes: 10 additions & 3 deletions python/metatensor-torch/tests/operations/contiguous.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
import metatensor.torch


def test_is_contiguous():
# TODO: write tests, used as a placeholder for now
assert True
def test_save_load():
with io.BytesIO() as buffer:
torch.jit.save(metatensor.torch.is_contiguous, buffer)
buffer.seek(0)
torch.jit.load(buffer)

with io.BytesIO() as buffer:
torch.jit.save(metatensor.torch.make_contiguous, buffer)
buffer.seek(0)
torch.jit.load(buffer)

0 comments on commit e2d5c01

Please sign in to comment.