Skip to content

Commit

Permalink
attempted fix for sort
Browse files Browse the repository at this point in the history
  • Loading branch information
DavideTisi committed Jun 5, 2024
1 parent 2095cb1 commit 10ae329
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def argsort_labels_values(labels_values, reverse: bool = False):
# so we temporary do this trick. this will be fixed with issue #366
max_int = torch.max(labels_values)
idx = torch.sum(
max_int ** torch.arange(labels_values.shape[1]) * labels_values, dim=1
max_int ** torch.arange(labels_values.shape[1])
* labels_values.flip(dims=[1]),
dim=1,
)
return torch.argsort(idx, dim=-1, descending=reverse)
elif isinstance(labels_values, np.ndarray):
Expand Down
122 changes: 122 additions & 0 deletions python/metatensor-torch/tests/operations/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,125 @@ def test_sort_descending(tensor, tensor_sorted_descending):
tensor_sorted_descending.block(0),
metatensor.torch.sort_block(tensor.block(0), descending=True),
)


def test_high_numb():
tensor = TensorMap(
keys=Labels(
names=["a", "b"],
values=torch.tensor([[2, 1], [1, 0]]),
),
blocks=[
TensorBlock(
values=torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
samples=Labels(
names=["s1", "s2", "s3"],
values=torch.tensor(
[
[0, 1, 2],
[2, 3, 4],
[1, 5, 7],
]
),
),
components=[],
properties=Labels(
names=["p1", "p2"],
values=torch.tensor(
[
[100, 0],
[5, 7000],
]
),
),
),
TensorBlock(
values=torch.tensor(
[[2.2, 3.1, 4.1], [2.2, 1.1, 2.1], [2.2, 5.1, 6.1]]
),
samples=Labels(
names=["s1", "s2", "s3"],
values=torch.tensor(
[
[0, 2, 2],
[0, 1, 2],
[1, 5, 7],
]
),
),
components=[],
properties=Labels(
names=["p1", "p2"],
values=torch.tensor(
[
[5, 10],
[5, 5],
[5, 6],
]
),
),
),
],
)

tensor_order = TensorMap(
keys=Labels(
names=["a", "b"],
values=torch.tensor([[1, 0], [2, 1]]),
),
blocks=[
TensorBlock(
values=torch.tensor(
[[1.1, 2.1, 2.2], [3.1, 4.1, 2.2], [5.1, 6.1, 2.2]]
),
samples=Labels(
names=["s1", "s2", "s3"],
values=torch.tensor(
[
[0, 1, 2],
[0, 2, 2],
[1, 5, 7],
]
),
),
components=[],
properties=Labels(
names=["p1", "p2"],
values=torch.tensor(
[
[5, 5],
[5, 6],
[5, 10],
]
),
),
),
TensorBlock(
values=torch.tensor([[2, 1], [6, 5],[4, 3]], dtype=torch.float32),
samples=Labels(
names=["s1", "s2", "s3"],
values=torch.tensor(
[
[0, 1, 2],
[1, 5, 7],
[2, 3, 4],

]
),
),
components=[],
properties=Labels(
names=["p1", "p2"],
values=torch.tensor(
[
[5, 7000],
[100, 0],

]
),
),
),
],
)
sorted = metatensor.torch.sort(tensor)
metatensor.torch.allclose_raise(sorted, tensor_order)

0 comments on commit 10ae329

Please sign in to comment.