From 5ebb12f42a9df081fd6add2bde82c55696b0b916 Mon Sep 17 00:00:00 2001
From: Davide Tisi
Date: Tue, 11 Jun 2024 16:10:03 +0200
Subject: [PATCH] Fix the bug with `metatensor.torch.sort` (#647)

---
 .../metatensor/operations/_dispatch.py        | 245 +++++++++++++++-
 .../metatensor-torch/tests/operations/sort.py | 262 ++++++++++++++++++
 2 files changed, 498 insertions(+), 9 deletions(-)

diff --git a/python/metatensor-operations/metatensor/operations/_dispatch.py b/python/metatensor-operations/metatensor/operations/_dispatch.py
index 5905fc323..f020a3c17 100644
--- a/python/metatensor-operations/metatensor/operations/_dispatch.py
+++ b/python/metatensor-operations/metatensor/operations/_dispatch.py
@@ -1,10 +1,10 @@
 import re
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 
-from ._backend import torch_jit_is_scripting
+from ._backend import torch_jit_is_scripting, torch_jit_script
 
 
 def parse_version(version):
@@ -124,13 +124,139 @@ def argsort_labels_values(labels_values, reverse: bool = False):
     :return: indices corresponding to the sorted values in ``labels_values``
     """
     if isinstance(labels_values, TorchTensor):
-        # torchscript does not support sorted for List[List[int]]
-        # so we temporary do this trick. this will be fixed with issue #366
-        max_int = torch.max(labels_values)
-        idx = torch.sum(
-            max_int ** torch.arange(labels_values.shape[1]) * labels_values, dim=1
-        )
-        return torch.argsort(idx, dim=-1, descending=reverse)
+        if labels_values.shape[1] == 1:
+            # Index is appended at the end to get
+            # the indices corresponding to the
+            # sorted values. Because we append the indices at
+            # the end and since metadata is unique, we do not affect
+            # the sorted order. Using `torch.argsort` is not an option
+            # because it only operates on Tensors, and sorting a Tensor is
+            # different from sorting a List[Tuple]
+            data = [(int(row[0]), i) for i, row in enumerate(labels_values)]
+            return torch.tensor(
+                [i[-1] for i in sort_list_2(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 2:
+            data = [
+                (int(row[0]), int(row[1]), i) for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_3(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 3:
+            data = [
+                (int(row[0]), int(row[1]), int(row[2]), i)
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_4(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 4:
+            data = [
+                (int(row[0]), int(row[1]), int(row[2]), int(row[3]), i)
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_5(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 5:
+            data = [
+                (int(row[0]), int(row[1]), int(row[2]), int(row[3]), int(row[4]), i)
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_6(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 6:
+            data = [
+                (
+                    int(row[0]),
+                    int(row[1]),
+                    int(row[2]),
+                    int(row[3]),
+                    int(row[4]),
+                    int(row[5]),
+                    i,
+                )
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_7(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 7:
+            data = [
+                (
+                    int(row[0]),
+                    int(row[1]),
+                    int(row[2]),
+                    int(row[3]),
+                    int(row[4]),
+                    int(row[5]),
+                    int(row[6]),
+                    i,
+                )
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_8(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 8:
+            data = [
+                (
+                    int(row[0]),
+                    int(row[1]),
+                    int(row[2]),
+                    int(row[3]),
+                    int(row[4]),
+                    int(row[5]),
+                    int(row[6]),
+                    int(row[7]),
+                    i,
+                )
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_9(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        if labels_values.shape[1] == 9:
+            data = [
+                (
+                    int(row[0]),
+                    int(row[1]),
+                    int(row[2]),
+                    int(row[3]),
+                    int(row[4]),
+                    int(row[5]),
+                    int(row[6]),
+                    int(row[7]),
+                    int(row[8]),
+                    i,
+                )
+                for i, row in enumerate(labels_values)
+            ]
+            return torch.tensor(
+                [i[-1] for i in sort_list_10(data, reverse=reverse)],
+                dtype=torch.int64,
+                device=labels_values.device,
+            )
+        else:
+            raise Exception("labels_values.shape[1]> 9 is not supported")
     elif isinstance(labels_values, np.ndarray):
         # Index is appended at the end to get the indices corresponding to the
         # sorted values. Because we append the indices at the end and since metadata
@@ -756,3 +882,104 @@ def zeros_like(array, shape: Optional[List[int]] = None, requires_grad: bool = F
         return np.zeros_like(array, shape=shape, subok=False)
     else:
         raise TypeError(UNKNOWN_ARRAY_TYPE)
+
+
+@torch_jit_script
+def sort_list_2(
+    data: List[Tuple[int, int]], reverse: bool = False
+) -> List[Tuple[int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_3(
+    data: List[Tuple[int, int, int]], reverse: bool = False
+) -> List[Tuple[int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_4(
+    data: List[Tuple[int, int, int, int]], reverse: bool = False
+) -> List[Tuple[int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_5(
+    data: List[Tuple[int, int, int, int, int]], reverse: bool = False
+) -> List[Tuple[int, int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_6(
+    data: List[Tuple[int, int, int, int, int, int]], reverse: bool = False
+) -> List[Tuple[int, int, int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_7(
+    data: List[Tuple[int, int, int, int, int, int, int]], reverse: bool = False
+) -> List[Tuple[int, int, int, int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_8(
+    data: List[Tuple[int, int, int, int, int, int, int, int]], reverse: bool = False
+) -> List[Tuple[int, int, int, int, int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_9(
+    data: List[Tuple[int, int, int, int, int, int, int, int, int]],
+    reverse: bool = False,
+) -> List[Tuple[int, int, int, int, int, int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
+
+
+@torch_jit_script
+def sort_list_10(
+    data: List[Tuple[int, int, int, int, int, int, int, int, int, int]],
+    reverse: bool = False,
+) -> List[Tuple[int, int, int, int, int, int, int, int, int, int]]:
+    # Using the `reverse` flag of sorted does not work with TorchScript
+    if reverse:
+        return list(sorted(data))[::-1]
+    else:
+        return list(sorted(data))
diff --git a/python/metatensor-torch/tests/operations/sort.py b/python/metatensor-torch/tests/operations/sort.py
index 1ed4bb53e..35d423372 100644
--- a/python/metatensor-torch/tests/operations/sort.py
+++ b/python/metatensor-torch/tests/operations/sort.py
@@ -1,9 +1,11 @@
 import io
 
+import pytest
 import torch
 from packaging import version
 
 import metatensor.torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
 
 from ._data import load_data
 
@@ -24,3 +26,263 @@ def test_save():
     torch.jit.save(metatensor.torch.sort, buffer)
     buffer.seek(0)
     torch.jit.load(buffer)
+
+
+@pytest.fixture
+def tensor():
+    # samples are descending, components and properties are ascending
+    block_1 = TensorBlock(
+        values=torch.tensor([[3, 5], [1, 2]]),
+        samples=Labels(["s"], torch.tensor([[2], [0]])),
+        components=[],
+        properties=Labels(["p"], torch.tensor([[0], [1]])),
+    )
+
+    block_1.add_gradient(
+        parameter="g",
+        gradient=TensorBlock(
+            values=torch.tensor([[[8, 3], [9, 4]], [[6, 1], [7, 2]]]),
+            samples=Labels(["sample", "g"], torch.tensor([[1, 1], [0, 1]])),
+            components=[
+                Labels(["c"], torch.tensor([[0], [1]])),
+            ],
+            properties=block_1.properties,
+        ),
+    )
+
+    # samples are disordered, components are ascending, properties are descending
+    block_2 = TensorBlock(
+        values=torch.tensor([[3, 4], [5, 6], [1, 2]]),
+        samples=Labels(["s"], torch.tensor([[7], [0], [2]])),
+        components=[],
+        properties=Labels(["p"], torch.tensor([[1], [0]])),
+    )
+    block_2.add_gradient(
+        parameter="g",
+        gradient=TensorBlock(
+            values=torch.tensor(
+                [[[15, 14], [11, 10]], [[13, 12], [15, 14]], [[11, 10], [13, 12]]]
+            ),
+            samples=Labels(
+                ["sample", "g"],
+                torch.tensor([[1, 1], [2, 1], [0, 1]]),
+            ),
+            components=[
+                Labels(["c"], torch.tensor([[0], [1]])),
+            ],
+            properties=block_2.properties,
+        ),
+    )
+    keys = Labels(names=["key_1", "key_2"], values=torch.tensor([[1, 0], [0, 0]]))
+    # block order is descending
+    return TensorMap(keys, [block_2, block_1])
+
+
+@pytest.fixture
+def tensor_sorted_ascending():
+    """
+    This is the `tensor` fixture sorted in ascending order, as it should be returned
+    when applying metatensor.torch.sort with the `descending=False` option.
+    """
+    block_1 = TensorBlock(
+        values=torch.tensor([[1, 2], [3, 5]]),
+        samples=Labels(["s"], torch.tensor([[0], [2]])),
+        components=[],
+        properties=Labels(["p"], torch.tensor([[0], [1]])),
+    )
+
+    block_1.add_gradient(
+        parameter="g",
+        gradient=TensorBlock(
+            values=torch.tensor([[[8, 3], [9, 4]], [[6, 1], [7, 2]]]),
+            samples=Labels(
+                ["sample", "g"],
+                torch.tensor([[0, 1], [1, 1]]),
+            ),
+            components=[
+                Labels(["c"], torch.tensor([[0], [1]])),
+            ],
+            properties=block_1.properties,
+        ),
+    )
+    block_2 = TensorBlock(
+        values=torch.tensor([[6, 5], [2, 1], [4, 3]]),
+        samples=Labels(["s"], torch.tensor([[0], [2], [7]])),
+        components=[],
+        properties=Labels(["p"], torch.tensor([[0], [1]])),
+    )
+
+    block_2.add_gradient(
+        parameter="g",
+        gradient=TensorBlock(
+            values=torch.tensor(
+                [[[14, 15], [10, 11]], [[12, 13], [14, 15]], [[10, 11], [12, 13]]]
+            ),
+            samples=Labels(
+                ["sample", "g"],
+                torch.tensor([[0, 1], [1, 1], [2, 1]]),
+            ),
+            components=[
+                Labels(["c"], torch.tensor([[0], [1]])),
+            ],
+            properties=block_2.properties,
+        ),
+    )
+
+    keys = Labels(names=["key_1", "key_2"], values=torch.tensor([[0, 0], [1, 0]]))
+    return TensorMap(keys, [block_1, block_2])
+
+
+@pytest.fixture
+def tensor_sorted_descending():
+    """
+    This is the `tensor` fixture sorted in descending order, as it should be returned
+    when applying metatensor.torch.sort with the `descending=True` option.
+    """
+    block_1 = TensorBlock(
+        values=torch.tensor([[5, 3], [2, 1]]),
+        samples=Labels(["s"], torch.tensor([[2], [0]])),
+        components=[],
+        properties=Labels(["p"], torch.tensor([[1], [0]])),
+    )
+
+    block_1.add_gradient(
+        parameter="g",
+        gradient=TensorBlock(
+            values=torch.tensor([[[4, 9], [3, 8]], [[2, 7], [1, 6]]]),
+            samples=Labels(
+                ["sample", "g"],
+                torch.tensor([[1, 1], [0, 1]]),
+            ),
+            components=[
+                Labels(["c"], torch.tensor([[1], [0]])),
+            ],
+            properties=block_1.properties,
+        ),
+    )
+
+    block_2 = TensorBlock(
+        values=torch.tensor([[3, 4], [1, 2], [5, 6]]),
+        samples=Labels(["s"], torch.tensor([[7], [2], [0]])),
+        components=[],
+        properties=Labels(["p"], torch.tensor([[1], [0]])),
+    )
+    block_2.add_gradient(
+        parameter="g",
+        gradient=TensorBlock(
+            values=torch.tensor(
+                [[[11, 10], [15, 14]], [[15, 14], [13, 12]], [[13, 12], [11, 10]]]
+            ),
+            samples=Labels(
+                ["sample", "g"],
+                torch.tensor([[2, 1], [1, 1], [0, 1]]),
+            ),
+            components=[
+                Labels(["c"], torch.tensor([[1], [0]])),
+            ],
+            properties=block_2.properties,
+        ),
+    )
+    keys = Labels(
+        names=["key_1", "key_2"],
+        values=torch.tensor([[1, 0], [0, 0]]),
+    )
+    return TensorMap(keys, [block_2, block_1])
+
+
+def test_sort_ascending(tensor, tensor_sorted_ascending):
+    metatensor.torch.allclose_block_raise(
+        metatensor.torch.sort_block(tensor.block(0)), tensor_sorted_ascending.block(1)
+    )
+    metatensor.torch.allclose_block_raise(
+        metatensor.torch.sort_block(tensor.block(1)), tensor_sorted_ascending.block(0)
+    )
+
+    metatensor.torch.allclose_raise(
+        metatensor.torch.sort(tensor), tensor_sorted_ascending
+    )
+
+
+def test_sort_descending(tensor, tensor_sorted_descending):
+    metatensor.torch.allclose_block_raise(
+        tensor_sorted_descending.block(0),
+        metatensor.torch.sort_block(tensor.block(0), descending=True),
+    )
+    metatensor.torch.allclose_block_raise(
+        tensor_sorted_descending.block(1),
+        metatensor.torch.sort_block(tensor.block(1), descending=True),
+    )
+
+
+def test_high_numb():
+    tensor = TensorMap(
+        keys=Labels(
+            names=["a", "b"],
+            values=torch.tensor([[2, 1], [1, 0]]),
+        ),
+        blocks=[
+            TensorBlock(
+                values=torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
+                samples=Labels(
+                    names=["s1", "s2", "s3"],
+                    values=torch.tensor([[0, 1, 2], [2, 3, 4], [1, 5, 7]]),
+                ),
+                components=[],
+                properties=Labels(
+                    names=["p1", "p2"],
+                    values=torch.tensor([[100, 0], [5, 7000]]),
+                ),
+            ),
+            TensorBlock(
+                values=torch.tensor(
+                    [[2.2, 3.1, 4.1], [2.2, 1.1, 2.1], [2.2, 5.1, 6.1]]
+                ),
+                samples=Labels(
+                    names=["s1", "s2", "s3"],
+                    values=torch.tensor([[0, 2, 2], [0, 1, 2], [1, 5, 7]]),
+                ),
+                components=[],
+                properties=Labels(
+                    names=["p1", "p2"],
+                    values=torch.tensor([[5, 10], [5, 5], [5, 6]]),
+                ),
+            ),
+        ],
+    )
+
+    tensor_order = TensorMap(
+        keys=Labels(
+            names=["a", "b"],
+            values=torch.tensor([[1, 0], [2, 1]]),
+        ),
+        blocks=[
+            TensorBlock(
+                values=torch.tensor(
+                    [[1.1, 2.1, 2.2], [3.1, 4.1, 2.2], [5.1, 6.1, 2.2]]
+                ),
+                samples=Labels(
+                    names=["s1", "s2", "s3"],
+                    values=torch.tensor([[0, 1, 2], [0, 2, 2], [1, 5, 7]]),
+                ),
+                components=[],
+                properties=Labels(
+                    names=["p1", "p2"],
+                    values=torch.tensor([[5, 5], [5, 6], [5, 10]]),
+                ),
+            ),
+            TensorBlock(
+                values=torch.tensor([[2, 1], [6, 5], [4, 3]], dtype=torch.float32),
+                samples=Labels(
+                    names=["s1", "s2", "s3"],
+                    values=torch.tensor([[0, 1, 2], [1, 5, 7], [2, 3, 4]]),
+                ),
+                components=[],
+                properties=Labels(
+                    names=["p1", "p2"],
+                    values=torch.tensor([[5, 7000], [100, 0]]),
+                ),
+            ),
+        ],
+    )
+    sorted_tensor = metatensor.torch.sort(tensor)
+    metatensor.torch.allclose_raise(sorted_tensor, tensor_order)