diff --git a/python/metatensor-operations/metatensor/operations/block_from_array.py b/python/metatensor-operations/metatensor/operations/block_from_array.py index d05508684..329e17f4d 100644 --- a/python/metatensor-operations/metatensor/operations/block_from_array.py +++ b/python/metatensor-operations/metatensor/operations/block_from_array.py @@ -1,5 +1,17 @@ +import numpy as np + from . import _dispatch -from ._backend import Labels, TensorBlock, torch_jit_script +from ._backend import Labels, TensorBlock, torch_jit_is_scripting, torch_jit_script + + +try: + import torch + + TorchScriptClass = torch.ScriptClass +except ImportError: + + class TorchScriptClass: + pass @torch_jit_script @@ -50,22 +62,37 @@ def block_from_array(array) -> TensorBlock: must have at least two dimensions. Too few provided: {n_dimensions}" ) + if torch_jit_is_scripting(): + # we are using metatensor-torch + labels_array_like = torch.empty(0) + else: + if isinstance(Labels, TorchScriptClass): + # we are using metatensor-torch + labels_array_like = torch.empty(0) + else: + # we are using metatensor-core + labels_array_like = np.empty(0) + samples = Labels( names=["sample"], - values=_dispatch.int_array_like(list(range(shape[0])), array).reshape(-1, 1), + values=_dispatch.int_array_like( + list(range(shape[0])), labels_array_like + ).reshape(-1, 1), ) components = [ Labels( names=[f"component_{component_index+1}"], - values=_dispatch.int_array_like(list(range(axis_size)), array).reshape( - -1, 1 - ), + values=_dispatch.int_array_like( + list(range(axis_size)), labels_array_like + ).reshape(-1, 1), ) for component_index, axis_size in enumerate(shape[1:-1]) ] properties = Labels( names=["property"], - values=_dispatch.int_array_like(list(range(shape[-1])), array).reshape(-1, 1), + values=_dispatch.int_array_like( + list(range(shape[-1])), labels_array_like + ).reshape(-1, 1), ) device = _dispatch.get_device(array) diff --git a/python/metatensor-operations/tests/block_from_array.py b/python/metatensor-operations/tests/block_from_array.py index 65998843e..7e44749a7 100644 --- a/python/metatensor-operations/tests/block_from_array.py +++ b/python/metatensor-operations/tests/block_from_array.py @@ -4,6 +4,14 @@ import metatensor +try: + import torch + + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + + @pytest.mark.parametrize("n_axes", [0, 1]) def test_too_few_axes(n_axes): """Test block_from_array when too few axes are provided.""" @@ -50,3 +58,28 @@ def test_with_components(): np.testing.assert_equal( block.properties.values, np.arange(array.shape[2]).reshape((-1, 1)) ) + + +@pytest.mark.skipif(not HAS_TORCH, reason="requires torch") +def test_torch_with_components(): + """Test block_from_array with components and torch arrays""" + array = array = torch.zeros((6, 5, 7)) + block = metatensor.block_from_array(array) + assert block.values is array + + assert block.samples.names == ["sample"] + np.testing.assert_equal( + block.samples.values, np.arange(array.shape[0]).reshape((-1, 1)) + ) + + assert len(block.components) == 1 + component = block.components[0] + assert component.names == ["component_1"] + np.testing.assert_equal( + component.values, np.arange(array.shape[1]).reshape((-1, 1)) + ) + + assert block.properties.names == ["property"] + np.testing.assert_equal( + block.properties.values, np.arange(array.shape[2]).reshape((-1, 1)) + )