Skip to content

Commit

Permalink
Fix block_from_array usage with torch array and core backend
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski authored and Luthaf committed Jul 3, 2024
1 parent 8e236b5 commit 0d9ad62
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import numpy as np

from . import _dispatch
from ._backend import Labels, TensorBlock, torch_jit_script
from ._backend import Labels, TensorBlock, torch_jit_is_scripting, torch_jit_script


try:
import torch

TorchScriptClass = torch.ScriptClass
except ImportError:

class TorchScriptClass:
pass


@torch_jit_script
Expand Down Expand Up @@ -50,22 +62,37 @@ def block_from_array(array) -> TensorBlock:
must have at least two dimensions. Too few provided: {n_dimensions}"
)

if torch_jit_is_scripting():
# we are using metatensor-torch
labels_array_like = torch.empty(0)
else:
if isinstance(Labels, TorchScriptClass):
# we are using metatensor-torch
labels_array_like = torch.empty(0)
else:
# we are using metatensor-core
labels_array_like = np.empty(0)

samples = Labels(
names=["sample"],
values=_dispatch.int_array_like(list(range(shape[0])), array).reshape(-1, 1),
values=_dispatch.int_array_like(
list(range(shape[0])), labels_array_like
).reshape(-1, 1),
)
components = [
Labels(
names=[f"component_{component_index+1}"],
values=_dispatch.int_array_like(list(range(axis_size)), array).reshape(
-1, 1
),
values=_dispatch.int_array_like(
list(range(axis_size)), labels_array_like
).reshape(-1, 1),
)
for component_index, axis_size in enumerate(shape[1:-1])
]
properties = Labels(
names=["property"],
values=_dispatch.int_array_like(list(range(shape[-1])), array).reshape(-1, 1),
values=_dispatch.int_array_like(
list(range(shape[-1])), labels_array_like
).reshape(-1, 1),
)

device = _dispatch.get_device(array)
Expand Down
33 changes: 33 additions & 0 deletions python/metatensor-operations/tests/block_from_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import metatensor


try:
import torch

HAS_TORCH = True
except ImportError:
HAS_TORCH = False


@pytest.mark.parametrize("n_axes", [0, 1])
def test_too_few_axes(n_axes):
"""Test block_from_array when too few axes are provided."""
Expand Down Expand Up @@ -50,3 +58,28 @@ def test_with_components():
np.testing.assert_equal(
block.properties.values, np.arange(array.shape[2]).reshape((-1, 1))
)


@pytest.mark.skipif(not HAS_TORCH, reason="requires torch")
def test_torch_with_components():
"""Test block_from_array with components and torch arrays"""
array = array = torch.zeros((6, 5, 7))
block = metatensor.block_from_array(array)
assert block.values is array

assert block.samples.names == ["sample"]
np.testing.assert_equal(
block.samples.values, np.arange(array.shape[0]).reshape((-1, 1))
)

assert len(block.components) == 1
component = block.components[0]
assert component.names == ["component_1"]
np.testing.assert_equal(
component.values, np.arange(array.shape[1]).reshape((-1, 1))
)

assert block.properties.names == ["property"]
np.testing.assert_equal(
block.properties.values, np.arange(array.shape[2]).reshape((-1, 1))
)

0 comments on commit 0d9ad62

Please sign in to comment.