Skip to content

Commit

Permalink
Add tests for percentile and quantile
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 19, 2024
1 parent 15a3718 commit cc8d4d9
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@
neg,
neq,
outer,
percentile,
polygamma,
power,
ptp,
quantile,
rad2deg,
reciprocal,
round_half_away_from_zero,
Expand Down Expand Up @@ -3732,3 +3734,69 @@ def test_nan_to_num(nan, posinf, neginf):
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)


@pytest.mark.parametrize(
"ndim, axis, q",
[
(2, None, 50),
(2, 1, 33),
(2, (0, 1), 50),
(3, (1, 2), 50),
(4, (1, 3, 0), 25),
(2, None, [25, 50, 75]),
(3, (1, 2), [10, 90]),
(3, 1, 75),
(3, 0, 50),
],
)
def test_percentile(ndim, axis, q):
shape = tuple(np.arange(1, ndim + 1))
data = np.random.rand(*shape)
x = tensor(shape=np.array(data).shape)
f = function([x], percentile(x, q, axis=axis))
result = f(data.astype(x.dtype))
expected = np.percentile(data.astype(x.dtype), q, axis=axis)
assert np.allclose(result, expected)


@pytest.mark.parametrize(
"ndim, axis, q",
[
(2, None, 0.5),
(2, None, [0.25, 0.75]),
(2, 0, 0.5),
(2, (0, 1), 0.5),
(3, None, 0.5),
(3, None, [0.25, 0.75]),
(3, 0, 0.5),
(3, (1, 2), 0.5),
],
)
def test_quantile(ndim, axis, q):
shape = tuple(np.random.randint(2, 6) for _ in range(ndim))
data = np.random.rand(*shape)

x = tensor(dtype="float64", shape=(None,) * ndim)
f = function([x], quantile(x, q, axis=axis))

result = f(data.astype(x.dtype))
expected = np.quantile(data.astype(x.dtype), q, axis=axis)

assert np.allclose(result, expected)


@pytest.mark.parametrize(
"ndim, axis, q, is_percentile",
[
(2, None, [50, 120], True),
(2, 1, -0.5, False),
],
)
def test_invalid_percentile_quamtile(ndim, axis, q, is_percentile):
x = tensor(dtype="float64", shape=(None,) * ndim)
with pytest.raises(ValueError):
if is_percentile:
percentile(x, q, axis)
else:
quantile(x, q, axis)

0 comments on commit cc8d4d9

Please sign in to comment.