Skip to content

Commit

Permalink
added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
DavideTisi committed Jun 11, 2024
1 parent 9644fcb commit c99f627
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/metatensor-operations/metatensor/operations/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def argsort_labels_values(labels_values, reverse: bool = False):
"""
if isinstance(labels_values, TorchTensor):
if labels_values.shape[1] == 1:
# Index is appended at the end to get
# the indices corresponding to the
# sorted values. Because we append the indices at
# the end and since metadata is unique, we do not affect
# the sorted order. Using `torch.argsort` is not an option
# because only on Tensors and sorting a Tensor is
# different than sorting a List(Tuple)
data = [(int(row[0]), i) for i, row in enumerate(labels_values)]
return torch.tensor(
[i[-1] for i in sort_list_2(data, reverse=reverse)],
Expand Down Expand Up @@ -881,6 +888,7 @@ def zeros_like(array, shape: Optional[List[int]] = None, requires_grad: bool = F
def sort_list_2(
data: List[Tuple[int, int]], reverse: bool = False
) -> List[Tuple[int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -891,6 +899,7 @@ def sort_list_2(
def sort_list_3(
data: List[Tuple[int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -901,6 +910,7 @@ def sort_list_3(
def sort_list_4(
data: List[Tuple[int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -911,6 +921,7 @@ def sort_list_4(
def sort_list_5(
data: List[Tuple[int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -921,6 +932,7 @@ def sort_list_5(
def sort_list_6(
data: List[Tuple[int, int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -931,6 +943,7 @@ def sort_list_6(
def sort_list_7(
data: List[Tuple[int, int, int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -941,6 +954,7 @@ def sort_list_7(
def sort_list_8(
data: List[Tuple[int, int, int, int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -952,6 +966,7 @@ def sort_list_9(
data: List[Tuple[int, int, int, int, int, int, int, int, int]],
reverse: bool = False,
) -> List[Tuple[int, int, int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand All @@ -963,6 +978,7 @@ def sort_list_10(
data: List[Tuple[int, int, int, int, int, int, int, int, int, int]],
reverse: bool = False,
) -> List[Tuple[int, int, int, int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
Expand Down

0 comments on commit c99f627

Please sign in to comment.