diff --git a/python/metatensor-operations/metatensor/operations/_dispatch.py b/python/metatensor-operations/metatensor/operations/_dispatch.py index 8befd6554..1f6f4efff 100644 --- a/python/metatensor-operations/metatensor/operations/_dispatch.py +++ b/python/metatensor-operations/metatensor/operations/_dispatch.py @@ -148,7 +148,9 @@ def argsort_labels_values(labels_values, reverse: bool = False): # so we temporary do this trick. this will be fixed with issue #366 max_int = torch.max(labels_values) idx = torch.sum( - max_int ** torch.arange(labels_values.shape[1]) * labels_values, dim=1 + max_int ** torch.arange(labels_values.shape[1]) + * labels_values.flip(dims=[1]), + dim=1, ) return torch.argsort(idx, dim=-1, descending=reverse) elif isinstance(labels_values, np.ndarray): diff --git a/python/metatensor-torch/tests/operations/sort.py b/python/metatensor-torch/tests/operations/sort.py index 36a4c765e..79782680c 100644 --- a/python/metatensor-torch/tests/operations/sort.py +++ b/python/metatensor-torch/tests/operations/sort.py @@ -305,3 +305,125 @@ def test_sort_descending(tensor, tensor_sorted_descending): tensor_sorted_descending.block(0), metatensor.torch.sort_block(tensor.block(0), descending=True), ) + + +def test_high_numb(): + tensor = TensorMap( + keys=Labels( + names=["a", "b"], + values=torch.tensor([[2, 1], [1, 0]]), + ), + blocks=[ + TensorBlock( + values=torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32), + samples=Labels( + names=["s1", "s2", "s3"], + values=torch.tensor( + [ + [0, 1, 2], + [2, 3, 4], + [1, 5, 7], + ] + ), + ), + components=[], + properties=Labels( + names=["p1", "p2"], + values=torch.tensor( + [ + [100, 0], + [5, 7000], + ] + ), + ), + ), + TensorBlock( + values=torch.tensor( + [[2.2, 3.1, 4.1], [2.2, 1.1, 2.1], [2.2, 5.1, 6.1]] + ), + samples=Labels( + names=["s1", "s2", "s3"], + values=torch.tensor( + [ + [0, 2, 2], + [0, 1, 2], + [1, 5, 7], + ] + ), + ), + components=[], + properties=Labels( + names=["p1", "p2"], + values=torch.tensor( + [ + [5, 10], + [5, 5], + [5, 6], + ] + ), + ), + ), + ], + ) + + tensor_order = TensorMap( + keys=Labels( + names=["a", "b"], + values=torch.tensor([[1, 0], [2, 1]]), + ), + blocks=[ + TensorBlock( + values=torch.tensor( + [[1.1, 2.1, 2.2], [3.1, 4.1, 2.2], [5.1, 6.1, 2.2]] + ), + samples=Labels( + names=["s1", "s2", "s3"], + values=torch.tensor( + [ + [0, 1, 2], + [0, 2, 2], + [1, 5, 7], + ] + ), + ), + components=[], + properties=Labels( + names=["p1", "p2"], + values=torch.tensor( + [ + [5, 5], + [5, 6], + [5, 10], + ] + ), + ), + ), + TensorBlock( + values=torch.tensor([[2, 1], [6, 5],[4, 3]], dtype=torch.float32), + samples=Labels( + names=["s1", "s2", "s3"], + values=torch.tensor( + [ + [0, 1, 2], + [1, 5, 7], + [2, 3, 4], + + ] + ), + ), + components=[], + properties=Labels( + names=["p1", "p2"], + values=torch.tensor( + [ + [5, 7000], + [100, 0], + + ] + ), + ), + ), + ], + ) + sorted = metatensor.torch.sort(tensor) + metatensor.torch.allclose_raise(sorted, tensor_order)