Skip to content

Commit

Permalink
Fix the bug with metatensor.torch.sort (#647)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavideTisi committed Jun 11, 2024
1 parent 18975d2 commit 5ebb12f
Show file tree
Hide file tree
Showing 2 changed files with 498 additions and 9 deletions.
245 changes: 236 additions & 9 deletions python/metatensor-operations/metatensor/operations/_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import re
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np

from ._backend import torch_jit_is_scripting
from ._backend import torch_jit_is_scripting, torch_jit_script


def parse_version(version):
Expand Down Expand Up @@ -124,13 +124,139 @@ def argsort_labels_values(labels_values, reverse: bool = False):
:return: indices corresponding to the sorted values in ``labels_values``
"""
if isinstance(labels_values, TorchTensor):
# torchscript does not support sorted for List[List[int]]
# so we temporary do this trick. this will be fixed with issue #366
max_int = torch.max(labels_values)
idx = torch.sum(
max_int ** torch.arange(labels_values.shape[1]) * labels_values, dim=1
)
return torch.argsort(idx, dim=-1, descending=reverse)
if labels_values.shape[1] == 1:
# Index is appended at the end to get
# the indices corresponding to the
# sorted values. Because we append the indices at
# the end and since metadata is unique, we do not affect
# the sorted order. Using `torch.argsort` is not an option
# because only on Tensors and sorting a Tensor is
# different than sorting a List(Tuple)
data = [(int(row[0]), i) for i, row in enumerate(labels_values)]
return torch.tensor(
[i[-1] for i in sort_list_2(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 2:
data = [
(int(row[0]), int(row[1]), i) for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_3(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 3:
data = [
(int(row[0]), int(row[1]), int(row[2]), i)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_4(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 4:
data = [
(int(row[0]), int(row[1]), int(row[2]), int(row[3]), i)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_5(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 5:
data = [
(int(row[0]), int(row[1]), int(row[2]), int(row[3]), int(row[4]), i)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_6(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 6:
data = [
(
int(row[0]),
int(row[1]),
int(row[2]),
int(row[3]),
int(row[4]),
int(row[5]),
i,
)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_7(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 7:
data = [
(
int(row[0]),
int(row[1]),
int(row[2]),
int(row[3]),
int(row[4]),
int(row[5]),
int(row[6]),
i,
)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_8(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 8:
data = [
(
int(row[0]),
int(row[1]),
int(row[2]),
int(row[3]),
int(row[4]),
int(row[5]),
int(row[6]),
int(row[7]),
i,
)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_9(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
if labels_values.shape[1] == 9:
data = [
(
int(row[0]),
int(row[1]),
int(row[2]),
int(row[3]),
int(row[4]),
int(row[5]),
int(row[6]),
int(row[7]),
int(row[8]),
i,
)
for i, row in enumerate(labels_values)
]
return torch.tensor(
[i[-1] for i in sort_list_10(data, reverse=reverse)],
dtype=torch.int64,
device=labels_values.device,
)
else:
raise Exception("labels_values.shape[1]> 9 is not supported")
elif isinstance(labels_values, np.ndarray):
# Index is appended at the end to get the indices corresponding to the
# sorted values. Because we append the indices at the end and since metadata
Expand Down Expand Up @@ -756,3 +882,104 @@ def zeros_like(array, shape: Optional[List[int]] = None, requires_grad: bool = F
return np.zeros_like(array, shape=shape, subok=False)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)


@torch_jit_script
def sort_list_2(
data: List[Tuple[int, int]], reverse: bool = False
) -> List[Tuple[int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_3(
data: List[Tuple[int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_4(
data: List[Tuple[int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_5(
data: List[Tuple[int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_6(
data: List[Tuple[int, int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_7(
data: List[Tuple[int, int, int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_8(
data: List[Tuple[int, int, int, int, int, int, int, int]], reverse: bool = False
) -> List[Tuple[int, int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_9(
data: List[Tuple[int, int, int, int, int, int, int, int, int]],
reverse: bool = False,
) -> List[Tuple[int, int, int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))


@torch_jit_script
def sort_list_10(
data: List[Tuple[int, int, int, int, int, int, int, int, int, int]],
reverse: bool = False,
) -> List[Tuple[int, int, int, int, int, int, int, int, int, int]]:
# Using the `reverse` flag of sorted does not work with TorchScript
if reverse:
return list(sorted(data))[::-1]
else:
return list(sorted(data))
Loading

0 comments on commit 5ebb12f

Please sign in to comment.