Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sort by #644

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/metatensor-operations/metatensor/operations/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def allclose(
raise TypeError(UNKNOWN_ARRAY_TYPE)


def argsort(values, axis: int = -1, reverse: bool = False):
"""
Similar to :py:func:`np.argsort`.

:param labels_values: numpy.array or torch.Tensor
:param reverse: if true, order is descending

:return: indices corresponding to the sorted values in ``labels_values``
"""
if isinstance(values, TorchTensor):
return torch.argsort(values, dim=axis, descending=reverse)
elif isinstance(values, np.ndarray):
tmp = np.argsort(values, axis=axis)
if reverse:
tmp = tmp[::-1]
return tmp
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)


def argsort_labels_values(labels_values, reverse: bool = False):
"""
Similar to :py:func:`np.argsort`, but sort the rows as one aggregated
Expand Down
120 changes: 107 additions & 13 deletions python/metatensor-operations/metatensor/operations/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def _sort_single_gradient_block(
gradient_block: TensorBlock,
axes: List[str],
descending: bool,
name: str = "-1",
) -> TensorBlock:
"""
Sorts a single gradient tensor block given the tensor block which the gradients are
Expand Down Expand Up @@ -43,9 +44,20 @@ def _sort_single_gradient_block(
# the parent block
block_sample_values = block.samples.values
# sample index -> sample labels
sorted_idx = _dispatch.argsort_labels_values(
block_sample_values, reverse=descending
)
if name == "-1":
sorted_idx = _dispatch.argsort_labels_values(
block_sample_values, reverse=descending
)
else:
axis = 0
for iv, v in enumerate(block.samples.names):
if v == name:
axis = iv
break
sorted_idx = _dispatch.argsort(
block_sample_values[:, axis], reverse=descending
)

# obtain inverse mapping sample labels -> sample index
sorted_idx_inverse = _dispatch.empty_like(sorted_idx, shape=(len(sorted_idx),))
sorted_idx_inverse[sorted_idx] = _dispatch.int_array_like(
Expand Down Expand Up @@ -95,13 +107,16 @@ def _sort_single_block(
block: TensorBlock,
axes: List[str],
descending: bool,
name: str = "-1",
) -> TensorBlock:
"""
Sorts a single TensorBlock without the user input checking and sorting of gradients
"""

sample_names = block.samples.names
sample_values = block.samples.values
if name != "-1" and name not in sample_names:
raise ValueError("`name` must be or '-1' or one of the sample names")

component_names: List[List[str]] = []
components_values = []
Expand All @@ -114,20 +129,50 @@ def _sort_single_block(

values = block.values
if "samples" in axes:
sorted_idx = _dispatch.argsort_labels_values(sample_values, reverse=descending)
if name == "-1":
sorted_idx = _dispatch.argsort_labels_values(
sample_values, reverse=descending
)
else:
axis = 0
for iv, v in enumerate(sample_names):
if v == name:
axis = iv
break
sorted_idx = _dispatch.argsort(sample_values[:, axis], reverse=descending)
sample_values = sample_values[sorted_idx]
values = values[sorted_idx]
if "components" in axes:
for i, _ in enumerate(block.components):
sorted_idx = _dispatch.argsort_labels_values(
components_values[i], reverse=descending
)
if name == "-1":
sorted_idx = _dispatch.argsort_labels_values(
components_values[i], reverse=descending
)
else:
axis = 0
for ic, c in enumerate(component_names[i]):
if c == name:
axis = ic
break
sorted_idx = _dispatch.argsort(
components_values[i][:, axis], reverse=descending
)
components_values[i] = components_values[i][sorted_idx]
values = _dispatch.take(values, sorted_idx, axis=i + 1)
if "properties" in axes:
sorted_idx = _dispatch.argsort_labels_values(
properties_values, reverse=descending
)
if name == "-1":
sorted_idx = _dispatch.argsort_labels_values(
properties_values, reverse=descending
)
else:
axis = 0
for ip, p in enumerate(property_names):
if p == name:
axis = ip
break
sorted_idx = _dispatch.argsort(
properties_values[:, axis], reverse=descending
)
properties_values = properties_values[sorted_idx]
values = _dispatch.take(values, sorted_idx, axis=-1)

Expand All @@ -151,6 +196,7 @@ def sort_block(
block: TensorBlock,
axes: Union[str, List[str]] = "all",
descending: bool = False,
name: str = "-1",
) -> TensorBlock:
"""
Rearrange the values of a block according to the order given by the sorted metadata
Expand All @@ -165,6 +211,9 @@ def sort_block(

:param descending: if false, the order is ascending

:param name: name of `axes` to be used for the sorting. Default == "-1"
means sort along the last axis

:return: sorted tensor block

>>> import numpy as np
Expand Down Expand Up @@ -226,14 +275,43 @@ def sort_block(
<BLANKLINE>
[[ 6, 7, 8],
[ 9, 10, 11]]])

>>> # You can also choose along which axis of "samples“ you sort
>>> block2 = TensorBlock(
... values=np.arange(12).reshape(4, 3),
... samples=Labels(
... ["system", "atom"], np.array([[0, 2], [1, 0], [2, 5], [2, 1]])
... ),
... components=[],
... properties=Labels(["n", "l"], np.array([[2, 0], [3, 0], [1, 0]])),
... )
>>> block_sorted_2_sample = metatensor.sort_block(
... block2, axes=["samples"], name="atom"
... )
>>> # samples (first dimension of the array) are sorted
>>> block_sorted_2_sample.values
array([[ 3, 4, 5],
[ 9, 10, 11],
[ 0, 1, 2],
[ 6, 7, 8]])
"""
if isinstance(axes, str):
if axes == "all":
axes_list = ["samples", "components", "properties"]
if name != "-1":
raise ValueError(
"'name' is allowed only if 'axes' is one of"
"'samples', 'components','properties' but"
"'axes'=='all'"
)
else:
axes_list = [axes]
elif isinstance(axes, list):
if len(axes) > 1 and name != "-1":
raise ValueError(
"'name' is allowed only if 'axes' is one of"
"'samples', 'components','properties' but"
"'axes' is a List"
)
axes_list = axes
else:
if torch_jit_is_scripting():
Expand All @@ -250,7 +328,7 @@ def sort_block(
f"not '{axis}'"
)

result_block = _sort_single_block(block, axes_list, descending)
result_block = _sort_single_block(block, axes_list, descending, name)

for parameter, gradient in block.gradients():
if len(gradient.gradients_list()) != 0:
Expand All @@ -259,7 +337,7 @@ def sort_block(
result_block.add_gradient(
parameter=parameter,
gradient=_sort_single_gradient_block(
block, gradient, axes_list, descending
block, gradient, axes_list, descending, name
),
)

Expand All @@ -271,6 +349,7 @@ def sort(
tensor: TensorMap,
axes: Union[str, List[str]] = "all",
descending: bool = False,
name: str = "-1",
) -> TensorMap:
"""
Sort the ``tensor`` according to the key values and the blocks for each specified
Expand All @@ -285,6 +364,8 @@ def sort(
Possible values are ``'keys'``, ``'samples'``, ``'components'``,
``'properties'`` and ``'all'`` to sort everything.
:param descending: if false, the order is ascending
:param name: name of `axes` to be used for the sorting. Default == "-1"
means sort along the last axis
:return: sorted tensor map

>>> import numpy as np
Expand Down Expand Up @@ -316,6 +397,12 @@ def sort(
if axes == "all":
axes_list = ["samples", "components", "properties"]
sort_keys = True
if name != "-1":
raise ValueError(
"'name' is allowed only if 'axes' is one of"
"'samples', 'components','properties' but"
"'axes'=='all'"
)
elif axes == "keys":
axes_list = torch_jit_annotate(List[str], [])
sort_keys = True
Expand All @@ -324,6 +411,12 @@ def sort(
sort_keys = False

elif isinstance(axes, list):
if len(axes) > 1 and name != "-1":
raise ValueError(
"'name' is allowed only if 'axes' is one of"
"'samples', 'components','properties' but"
"'axes' is a List"
)
axes_list = axes

if "keys" in axes_list:
Expand Down Expand Up @@ -365,6 +458,7 @@ def sort(
block=tensor.block(tensor.keys[int(i)]),
axes=axes_list,
descending=descending,
name=name,
)
)

Expand Down
Loading