Skip to content

Commit

Permalink
Vectorised percentile/quantile
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Oct 15, 2024
1 parent 17e6aa3 commit dcdeef7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 66 deletions.
118 changes: 52 additions & 66 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2926,99 +2926,85 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
return x


def quantile(input, q, axis=None):
def quantile(x: TensorLike, q: float | list[float], axis=None) -> TensorVariable:
"""
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
Computes the q-th quantile along the given axis(es) of a tensor `x`.
Parameters
----------
input: TensorVariable
x: TensorVariable
The input tensor.
q: float or list of floats
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
axis: None or int or list of int, optional
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
q: float or list of float
The quantile(s) to compute, which must be between 0 and 1 inclusive.
0 corresponds to the minimum, 0.5 to the median, and 1 to the maximum.
axis: None or int or (list of int) (see `Sum`)
Compute the quantile along this axis of the tensor.
None means all axes (like numpy).
"""
x = as_tensor_variable(input)
x_ndim = x.type.ndim

x = as_tensor_variable(x)
q = as_tensor_variable(q)
x_ndim = x.type.ndim
if axis is None:
axis = list(range(x_ndim))
elif isinstance(axis, (int | np.integer)):
axis = [axis]
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]

# Compute the shape of the remaining axes
new_axes_order = [i for i in range(x.ndim) if i not in axis] + list(axis)
x = x.dimshuffle(new_axes_order)
input_shape = shape(x)
remaining_axis_size = input_shape[: x.ndim - len(axis)]
x = x.reshape((*remaining_axis_size, -1))

# Sort the input tensor along the specified axis
sorted_input = x.sort(axis=-1)
slices1 = [slice(None)] * sorted_input.ndim
slices2 = [slice(None)] * sorted_input.ndim
input_shape = x.shape[-1]
axis = list(normalize_axis_tuple(axis, x_ndim))

if isinstance(q, (int | float)):
q = [q]
non_axis = [i for i in range(x_ndim) if i not in axis]
non_axis_shape = [x.shape[i] for i in non_axis]

for quantile in q:
if quantile < 0 or quantile > 1:
raise ValueError("Quantiles must be in the range [0, 1]")
# Put axis at the end and unravel them
x_raveled = x.transpose(*non_axis, *axis)
if len(axis) > 1:
x_raveled = x_raveled.reshape((*non_axis_shape, -1))
raveled_size = x_raveled.shape[-1]

result = []
for quantile in q:
k = (quantile) * (input_shape - 1)
k_floor = floor(k).astype("int64")
k_ceil = ceil(k).astype("int64")
# Ensure q is between 0 and 1
q = clip(q, 0.0, 1.0)

slices1[-1] = slice(k_floor, k_floor + 1)
slices2[-1] = slice(k_ceil, k_ceil + 1)
val1 = sorted_input[tuple(slices1)]
val2 = sorted_input[tuple(slices2)]
# Compute quantile indices
k = (q * (raveled_size - 1)).astype("int64")
k_float = q * (raveled_size - 1)

d = k - k_floor
quantile_val = val1 + d * (val2 - val1)
# Sort the input tensor along the specified axis
x_sorted = x_raveled.sort(axis=-1)

result.append(quantile_val.squeeze(axis=-1))
# Get the values at index k and k + 1 for linear interpolation
k_values = x_sorted[..., k]
kp1_values = x_sorted[..., minimum(k + 1, raveled_size - 1)]

if len(result) == 1:
result = result[0]
else:
result = stack(result)
# Interpolation between the two values if needed
frac = k_float - k.astype(k_float.dtype)
quantile_value = (1 - frac) * k_values + frac * kp1_values

result.name = "quantile"
return result
return quantile_value


def percentile(input, q, axis=None):
def percentile(x: TensorLike, p: float | list[float], axis=None) -> TensorVariable:
"""
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
Computes the p-th percentile along the given axis(es) of a tensor `x`.
Parameters
----------
input: TensorVariable
x: TensorVariable
The input tensor.
q: float or list of floats
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
axis: None or int or list of int, optional
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
"""
if isinstance(q, (int | float)):
q = [q]

for percentile in q:
if percentile < 0 or percentile > 100:
raise ValueError("Percentiles must be in the range [0, 100]")
p: float or list of float
The percentile(s) to compute, which must be between 0 and 100 inclusive.
0 corresponds to the minimum, 50 to the median, and 100 to the maximum.
axis: None or int or (list of int) (see `Sum`)
Compute the percentile along this axis of the tensor.
None means all axes (like numpy).
quantiles = [x / 100 for x in q]
Returns
-------
TensorVariable
The computed percentile values.
"""
# Convert percentiles (0-100) to quantiles (0-1)
q = as_tensor_variable(p) / 100.0

return quantile(input, quantiles, axis)
# Call the quantile function
return quantile(x, q, axis=axis)


# NumPy logical aliases
Expand Down
1 change: 1 addition & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3805,6 +3805,7 @@ def test_percentile(ndim, axis, q):
(3, None, [0.25, 0.75]),
(3, 0, 0.5),
(3, (1, 2), 0.5),
(3, (1, 2), [0.1, 0.9]),
],
)
def test_quantile(ndim, axis, q):
Expand Down

0 comments on commit dcdeef7

Please sign in to comment.