Skip to content

Commit

Permalink
fix block_from_array for torch array backend (metatensor-core)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Feb 5, 2024
1 parent 30acbf1 commit fe3b83c
Showing 1 changed file with 35 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import numpy as np

from . import _dispatch
from ._classes import Labels, TensorBlock
from ._classes import Labels, TensorBlock, check_isinstance, torch_jit_is_scripting


try:
import torch

HAS_TORCH = True
except ImportError:
HAS_TORCH = False


def block_from_array(array) -> TensorBlock:
Expand Down Expand Up @@ -49,22 +59,41 @@ def block_from_array(array) -> TensorBlock:
must have at least two dimensions. Too few provided: {n_dimensions}"
)

if torch_jit_is_scripting():
# for metatensor-torch the array backend is always torch
labels_array_like = torch.empty(0)
else:
# This will return true only if we are using metatensor-torch
# please refer to function definition in files
# metatensor-core python/metatensor-operations/metatensor/operations/_classes.py
# metatensor-torch python/metatensor-torch/metatensor/torch/operations.py
if check_isinstance(None, Labels):
# for metatensor-torch the array backend is always torch
labels_array_like = torch.empty(0)
else:
# for metatensor-core the array backend is always numpy
labels_array_like = np.empty(0)

samples = Labels(
names=["sample"],
values=_dispatch.int_array_like(list(range(shape[0])), array).reshape(-1, 1),
values=_dispatch.int_array_like(
list(range(shape[0])), labels_array_like
).reshape(-1, 1),
)
components = [
Labels(
names=[f"component_{component_index+1}"],
values=_dispatch.int_array_like(list(range(axis_size)), array).reshape(
-1, 1
),
values=_dispatch.int_array_like(
list(range(axis_size)), labels_array_like
).reshape(-1, 1),
)
for component_index, axis_size in enumerate(shape[1:-1])
]
properties = Labels(
names=["property"],
values=_dispatch.int_array_like(list(range(shape[-1])), array).reshape(-1, 1),
values=_dispatch.int_array_like(
list(range(shape[-1])), labels_array_like
).reshape(-1, 1),
)

device = _dispatch.get_device(array)
Expand Down

0 comments on commit fe3b83c

Please sign in to comment.