def argsort(values, axis: int = -1, reverse: bool = False):
    """
    Similar to :py:func:`np.argsort`, with an optional descending order.

    :param values: numpy.array or torch.Tensor
    :param axis: axis along which to sort, the default (``-1``) is the last axis
    :param reverse: if true, order is descending

    :return: indices that sort ``values`` along ``axis``
    :raises TypeError: if ``values`` is neither a numpy array nor a torch tensor
    """
    if isinstance(values, TorchTensor):
        return torch.argsort(values, dim=axis, descending=reverse)
    elif isinstance(values, np.ndarray):
        indices = np.argsort(values, axis=axis)
        if reverse:
            # Flip along the axis that was sorted. ``indices[::-1]`` would always
            # reverse the first axis, which is only correct for 1D input (or
            # ``axis=0``); for 1D arrays both forms give the same result.
            indices = np.flip(indices, axis=axis)
        return indices
    else:
        raise TypeError(UNKNOWN_ARRAY_TYPE)
for iv, v in enumerate(block.samples.names): + if v == name: + axis = iv + break + sorted_idx = _dispatch.argsort( + block_sample_values[:, axis], reverse=descending + ) + # obtain inverse mapping sample labels -> sample index sorted_idx_inverse = _dispatch.empty_like(sorted_idx, shape=(len(sorted_idx),)) sorted_idx_inverse[sorted_idx] = _dispatch.int_array_like( @@ -95,6 +107,7 @@ def _sort_single_block( block: TensorBlock, axes: List[str], descending: bool, + name: str = "-1", ) -> TensorBlock: """ Sorts a single TensorBlock without the user input checking and sorting of gradients @@ -102,6 +115,8 @@ def _sort_single_block( sample_names = block.samples.names sample_values = block.samples.values + if name != "-1" and name not in sample_names: + raise ValueError("`name` must be or '-1' or one of the sample names") component_names: List[List[str]] = [] components_values = [] @@ -114,20 +129,50 @@ def _sort_single_block( values = block.values if "samples" in axes: - sorted_idx = _dispatch.argsort_labels_values(sample_values, reverse=descending) + if name == "-1": + sorted_idx = _dispatch.argsort_labels_values( + sample_values, reverse=descending + ) + else: + axis = 0 + for iv, v in enumerate(sample_names): + if v == name: + axis = iv + break + sorted_idx = _dispatch.argsort(sample_values[:, axis], reverse=descending) sample_values = sample_values[sorted_idx] values = values[sorted_idx] if "components" in axes: for i, _ in enumerate(block.components): - sorted_idx = _dispatch.argsort_labels_values( - components_values[i], reverse=descending - ) + if name == "-1": + sorted_idx = _dispatch.argsort_labels_values( + components_values[i], reverse=descending + ) + else: + axis = 0 + for ic, c in enumerate(component_names[i]): + if c == name: + axis = ic + break + sorted_idx = _dispatch.argsort( + components_values[i][:, axis], reverse=descending + ) components_values[i] = components_values[i][sorted_idx] values = _dispatch.take(values, sorted_idx, axis=i + 1) if 
"properties" in axes: - sorted_idx = _dispatch.argsort_labels_values( - properties_values, reverse=descending - ) + if name == "-1": + sorted_idx = _dispatch.argsort_labels_values( + properties_values, reverse=descending + ) + else: + axis = 0 + for ip, p in enumerate(property_names): + if p == name: + axis = ip + break + sorted_idx = _dispatch.argsort( + properties_values[:, axis], reverse=descending + ) properties_values = properties_values[sorted_idx] values = _dispatch.take(values, sorted_idx, axis=-1) @@ -151,6 +196,7 @@ def sort_block( block: TensorBlock, axes: Union[str, List[str]] = "all", descending: bool = False, + name: str = "-1", ) -> TensorBlock: """ Rearrange the values of a block according to the order given by the sorted metadata @@ -165,6 +211,9 @@ def sort_block( :param descending: if false, the order is ascending + :param name: name of `axes` to be used for the sorting. Default == "-1" + means sort along the last axis + :return: sorted tensor block >>> import numpy as np @@ -226,14 +275,43 @@ def sort_block( [[ 6, 7, 8], [ 9, 10, 11]]]) - + >>> # You can also choose along which axis of "samples“ you sort + >>> block2 = TensorBlock( + ... values=np.arange(12).reshape(4, 3), + ... samples=Labels( + ... ["system", "atom"], np.array([[0, 2], [1, 0], [2, 5], [2, 1]]) + ... ), + ... components=[], + ... properties=Labels(["n", "l"], np.array([[2, 0], [3, 0], [1, 0]])), + ... ) + >>> block_sorted_2_sample = metatensor.sort_block( + ... block2, axes=["samples"], name="atom" + ... 
) + >>> # samples (first dimension of the array) are sorted + >>> block_sorted_2_sample.values + array([[ 3, 4, 5], + [ 9, 10, 11], + [ 0, 1, 2], + [ 6, 7, 8]]) """ if isinstance(axes, str): if axes == "all": axes_list = ["samples", "components", "properties"] + if name != "-1": + raise ValueError( + "'name' is allowed only if 'axes' is one of" + "'samples', 'components','properties' but" + "'axes'=='all'" + ) else: axes_list = [axes] elif isinstance(axes, list): + if len(axes) > 1 and name != "-1": + raise ValueError( + "'name' is allowed only if 'axes' is one of" + "'samples', 'components','properties' but" + "'axes' is a List" + ) axes_list = axes else: if torch_jit_is_scripting(): @@ -250,7 +328,7 @@ def sort_block( f"not '{axis}'" ) - result_block = _sort_single_block(block, axes_list, descending) + result_block = _sort_single_block(block, axes_list, descending, name) for parameter, gradient in block.gradients(): if len(gradient.gradients_list()) != 0: @@ -259,7 +337,7 @@ def sort_block( result_block.add_gradient( parameter=parameter, gradient=_sort_single_gradient_block( - block, gradient, axes_list, descending + block, gradient, axes_list, descending, name ), ) @@ -271,6 +349,7 @@ def sort( tensor: TensorMap, axes: Union[str, List[str]] = "all", descending: bool = False, + name: str = "-1", ) -> TensorMap: """ Sort the ``tensor`` according to the key values and the blocks for each specified @@ -285,6 +364,8 @@ def sort( Possible values are ``'keys'``, ``'samples'``, ``'components'``, ``'properties'`` and ``'all'`` to sort everything. :param descending: if false, the order is ascending + :param name: name of `axes` to be used for the sorting. 
@pytest.fixture
def tensor_two_samples():
    """
    Single-block TensorMap whose samples have two dimensions (``"s"``, ``"a"``)
    in no particular order, with one gradient ``"g"`` attached to the block.
    """
    block = TensorBlock(
        values=np.array([[3, 5], [1, 2], [-1, -2], [11, 22], [41, 42]]),
        samples=Labels(
            ["s", "a"],
            np.array([[2, 5], [0, 0], [0, 9], [0, 1], [0, 110]]),
        ),
        components=[],
        properties=Labels(["p"], np.array([[0], [1]])),
    )

    gradient = TensorBlock(
        values=np.array(
            [
                [[1, 2], [3, 4]],
                [[5, 6], [7, 8]],
                [[9, 10], [11, 12]],
                [[13, 14], [15, 16]],
                [[17, 18], [19, 20]],
                [[21, 22], [23, 24]],
                [[25, 26], [27, 28]],
            ]
        ),
        samples=Labels(
            ["sample", "g"],
            np.array([[2, 1], [1, 1], [0, 1], [0, 2], [3, 1], [4, 1], [4, 0]]),
        ),
        components=[Labels(["c"], np.array([[1], [0]]))],
        properties=block.properties,
    )
    block.add_gradient(parameter="g", gradient=gradient)

    single_key = Labels(names=["key_1"], values=np.array([[0]]))
    return TensorMap(single_key, [block])


@pytest.fixture
def tensor_two_samples_ascending_a():
    """
    Expected content of ``tensor_two_samples`` once its samples are sorted in
    ascending order of the ``"a"`` dimension only; the gradient ``"g"`` is part
    of the expected block.
    """
    block = TensorBlock(
        values=np.array([[1, 2], [11, 22], [3, 5], [-1, -2], [41, 42]]),
        samples=Labels(
            ["s", "a"],
            np.array([[0, 0], [0, 1], [2, 5], [0, 9], [0, 110]]),
        ),
        components=[],
        properties=Labels(["p"], np.array([[0], [1]])),
    )

    gradient = TensorBlock(
        values=np.array(
            [
                [[5, 6], [7, 8]],
                [[17, 18], [19, 20]],
                [[9, 10], [11, 12]],
                [[13, 14], [15, 16]],
                [[1, 2], [3, 4]],
                [[25, 26], [27, 28]],
                [[21, 22], [23, 24]],
            ]
        ),
        samples=Labels(
            ["sample", "g"],
            np.array([[0, 1], [1, 1], [2, 1], [2, 2], [3, 1], [4, 0], [4, 1]]),
        ),
        components=[Labels(["c"], np.array([[1], [0]]))],
        properties=block.properties,
    )
    block.add_gradient(parameter="g", gradient=gradient)

    single_key = Labels(names=["key_1"], values=np.array([[0]]))
    return TensorMap(single_key, [block])


@pytest.fixture
def tensor_two_samples_descending_a():
    """
    Expected content of ``tensor_two_samples`` (without gradients) once its
    samples are sorted in descending order of the ``"a"`` dimension only.
    """
    block = TensorBlock(
        values=np.array([[41, 42], [-1, -2], [3, 5], [11, 22], [1, 2]]),
        samples=Labels(
            ["s", "a"],
            np.array([[0, 110], [0, 9], [2, 5], [0, 1], [0, 0]]),
        ),
        components=[],
        properties=Labels(["p"], np.array([[0], [1]])),
    )

    single_key = Labels(names=["key_1"], values=np.array([[0]]))
    return TensorMap(single_key, [block])
def test_sort_two_sample(tensor_two_samples, tensor_two_samples_ascending_a):
    """
    Sort the samples of a block by the single dimension ``"a"`` and compare with
    the expected block (which carries the expected gradient ``"g"`` as well).
    """
    # No debugging output here: the comparison below already raises on mismatch.
    sorted_block = metatensor.sort_block(
        tensor_two_samples.block(0), axes="samples", name="a"
    )

    metatensor.allclose_block_raise(
        sorted_block,
        tensor_two_samples_ascending_a.block(0),
    )


def test_sort_two_sample_descending(
    tensor_two_samples, tensor_two_samples_descending_a
):
    """
    Same as ``test_sort_two_sample`` but in descending order; gradients are
    removed first because the expected fixture has none.
    """
    without_gradients = metatensor.remove_gradients(tensor_two_samples)
    sorted_block = metatensor.sort_block(
        without_gradients.block(0), axes="samples", name="a", descending=True
    )

    metatensor.allclose_block_raise(
        sorted_block,
        tensor_two_samples_descending_a.block(0),
    )