Skip to content

Commit

Permalink
Support axis parameter in numpy.argsort
Browse files Browse the repository at this point in the history
Related to #2013
  • Loading branch information
serge-sans-paille committed Sep 28, 2022
1 parent d22a954 commit 1a517ac
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 15 deletions.
7 changes: 6 additions & 1 deletion pythran/pythonic/include/numpy/argsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ PYTHONIC_NS_BEGIN

namespace numpy
{
template <class E>
types::ndarray<long, types::array<long, 1>> argsort(E const &expr,
types::none_type);

template <class T, class pS>
types::ndarray<long, pS> argsort(types::ndarray<T, pS> const &a);
types::ndarray<long, pS> argsort(types::ndarray<T, pS> const &a,
long axis = -1);

NUMPY_EXPR_TO_NDARRAY0_DECL(argsort);

Expand Down
68 changes: 54 additions & 14 deletions pythran/pythonic/numpy/argsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,64 @@ PYTHONIC_NS_BEGIN

namespace numpy
{
template <class E>
types::ndarray<long, types::array<long, 1>> argsort(E const &expr,
types::none_type)
{
auto out = functor::array{}(expr).flat();
return argsort(out);
}

template <class T, class pS>
types::ndarray<long, pS> argsort(types::ndarray<T, pS> const &a)
types::ndarray<long, pS> argsort(types::ndarray<T, pS> const &a, long axis)
{
constexpr auto N = std::tuple_size<pS>::value;
size_t last_axis = a.template shape<N - 1>();
size_t n = a.flat_size();
if (axis < 0)
axis += N;

long const flat_size = a.flat_size();
types::ndarray<long, pS> indices(a._shape, builtins::None);
for (long j = 0, *iter_indices = indices.buffer,
*end_indices = indices.buffer + n;
iter_indices != end_indices;
iter_indices += last_axis, j += last_axis) {
// fill with the original indices
std::iota(iter_indices, iter_indices + last_axis, 0L);
// sort the index using the value from a
pdqsort(iter_indices, iter_indices + last_axis,
[&a, j](long i1, long i2) {
return *(a.fbegin() + j + i1) < *(a.fbegin() + j + i2);
});
if (axis == N - 1) {
size_t step = a.template shape<N - 1>();

auto a_base = a.fbegin();
for (long *iter_indices = indices.buffer,
*end_indices = indices.buffer + flat_size;
iter_indices != end_indices; iter_indices += step, a_base += step) {
// fill with the original indices
std::iota(iter_indices, iter_indices + step, 0L);
// sort the index using the value from a
pdqsort(iter_indices, iter_indices + step,
[a_base](long i1, long i2) { return a_base[i1] < a_base[i2]; });
}
} else {
auto out_shape = sutils::getshape(a);
const long step =
std::accumulate(out_shape.begin() + axis, out_shape.end(), 1L,
std::multiplies<long>());
long const buffer_size = out_shape[axis];
const long stepper = step / out_shape[axis];
const long n = flat_size / out_shape[axis];
long ith = 0, nth = 0;
std::unique_ptr<long[]> buffer{new long[buffer_size]};
long *buffer_start = buffer.get(),
*buffer_end = buffer.get() + buffer_size;
std::iota(buffer_start, buffer_end, 0L);
for (long i = 0; i < n; i++) {
auto a_base = a.fbegin() + ith;
pdqsort(buffer.get(), buffer.get() + buffer_size,
[a_base, stepper](long i1, long i2) {
return a_base[i1 * stepper] < a_base[i2 * stepper];
});

for (long j = 0; j < buffer_size; ++j)
indices.buffer[ith + j * stepper] = buffer[j];

ith = step;
if (ith >= flat_size) {
ith = ++nth;
}
}
}
return indices;
}
Expand Down
9 changes: 9 additions & 0 deletions pythran/tests/test_numpy_func2.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,15 @@ def test_argsort0(self):
def test_argsort1(self):
self.run_test("def np_argsort1(x): return x.argsort()", numpy.array([[3, 1, 2], [1 , 2, 3]]), np_argsort1=[NDArray[int,:,:]])

def test_argsort2(self):
self.run_test("def np_argsort2(x): return x.argsort(axis=0)", numpy.array([[3, 1, 2], [1 , 2, 3]]), np_argsort2=[NDArray[int,:,:]])

def test_argsort3(self):
self.run_test("def np_argsort3(x): return x.argsort(axis=1)", numpy.array([[3, 1, 2], [1 , 2, 3]]), np_argsort3=[NDArray[int,:,:]])

def test_argsort4(self):
self.run_test("def np_argsort4(x): return x.argsort(axis=None)", numpy.array([[3, 1, 2], [1 , 2, 3]]), np_argsort4=[NDArray[int,:,:]])

def test_argmax0(self):
self.run_test("def np_argmax0(a): return a.argmax()", numpy.arange(6).reshape(2,3), np_argmax0=[NDArray[int,:,:]])

Expand Down

0 comments on commit 1a517ac

Please sign in to comment.