Implement Ordered distribution factory #7603

Draft: ricardoV94 wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
@@ -103,6 +103,7 @@
WishartBartlett,
ZeroSumNormal,
)
from pymc.distributions.ordered import Ordered
from pymc.distributions.simulator import Simulator
from pymc.distributions.timeseries import (
AR,
@@ -178,6 +179,7 @@
"NegativeBinomial",
"Normal",
"NormalMixture",
"Ordered",
"OrderedLogistic",
"OrderedMultinomial",
"OrderedProbit",
6 changes: 2 additions & 4 deletions pymc/distributions/discrete.py
@@ -1239,8 +1239,7 @@ class OrderedLogistic:

# Ordered logistic regression
with pm.Model() as model:
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
transform=pm.distributions.transforms.ordered)
cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
idata = pm.sample()

@@ -1343,8 +1342,7 @@ class OrderedProbit:

# Ordered probit regression
with pm.Model() as model:
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
transform=pm.distributions.transforms.ordered)
cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
idata = pm.sample()

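For context on the docstring change above: the manual transform-based pattern is replaced by the new factory. Below is a minimal end-to-end sketch, assuming this branch is installed (pm.Ordered is not in released PyMC; the simulated data, seed, and bins are illustrative):

    import numpy as np
    import pymc as pm

    # Simulated ordinal outcome with 3 categories, hence 2 cutpoints
    rng = np.random.default_rng(123)
    x = rng.normal(size=200)
    y = np.digitize(x, bins=[-0.5, 0.5])  # categories {0, 1, 2}

    with pm.Model():
        # IID Normal(0, 10) components, sorted by the factory; no manual
        # transform=pm.distributions.transforms.ordered boilerplate needed
        cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
        pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
        idata = pm.sample()
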
133 changes: 133 additions & 0 deletions pymc/distributions/ordered.py
@@ -0,0 +1,133 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytensor.tensor as pt

from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorVariable

from pymc.distributions.distribution import (
Distribution,
SymbolicRandomVariable,
_support_point,
)
from pymc.distributions.shape_utils import change_dist_size, get_support_shape_1d, rv_size_is_none
from pymc.distributions.transforms import _default_transform, ordered


class OrderedRV(SymbolicRandomVariable):
inline_logprob = True
extended_signature = "(x)->(x)"
_print_name = ("Ordered", "\\operatorname{Ordered}")

@classmethod
def rv_op(cls, dist, *, size=None):
# We don't allow passing `rng` because we don't fully control the rng of the components!

size = normalize_size_param(size)

if not rv_size_is_none(size):
core_shape = tuple(dist.shape)[-1]
shape = (*tuple(size), core_shape)
dist = change_dist_size(dist, shape)

sorted_rv = pt.sort(dist, axis=-1)

return OrderedRV(
inputs=[dist],
outputs=[sorted_rv],
)(dist)


class Ordered(Distribution):
r"""Univariate IID Ordered distribution.

    The pdf of the ordered distribution is

    .. math::

        f(x_1, \dots, x_n) = n! \prod_{i=1}^n f(x_{(i)}),

    where :math:`x_1 \leq x_2 \leq \dots \leq x_n`.

Parameters
----------
dist: unnamed_distribution
Univariate IID distribution which will be sorted.

        .. warning:: dist will be cloned, rendering it independent of the one passed as input

Examples
--------
    .. code-block:: python

        import pymc as pm

with pm.Model():
x = pm.Normal.dist(mu=0, sigma=1) # Must be IID
ordered_x = pm.Ordered("ordered_x", dist=x, shape=(3,))

pm.draw(ordered_x, random_seed=52) # array([0.05172346, 0.43970706, 0.91500416])
"""

rv_type = OrderedRV
rv_op = OrderedRV.rv_op

def __new__(cls, name, dist, *, support_shape=None, **kwargs):
support_shape = get_support_shape_1d(
support_shape=support_shape,
shape=None, # shape will be checked in `cls.dist`
dims=kwargs.get("dims", None),
observed=kwargs.get("observed", None),
)
return super().__new__(cls, name, dist, support_shape=support_shape, **kwargs)

@classmethod
def dist(cls, dist, *, support_shape=None, **kwargs):
        if (
            not isinstance(dist, TensorVariable)
            or dist.owner is None
            or not isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable)
        ):
            raise ValueError(
                f"Ordered dist must be a distribution created via the `.dist()` API, got {type(dist)}"
            )
if dist.owner.op.ndim_supp > 0:
raise NotImplementedError("Ordering of multivariate distributions not supported")
if not all(
all(param.type.broadcastable) for param in dist.owner.op.dist_params(dist.owner)
):
raise ValueError("Ordered dist must be an IID variable")

support_shape = get_support_shape_1d(
support_shape=support_shape,
shape=kwargs.get("shape", None),
)
if support_shape is not None:
dist = change_dist_size(dist, support_shape)

dist = pt.atleast_1d(dist)

return super().dist([dist], **kwargs)


@_default_transform.register(OrderedRV)
def default_transform_ordered(op, rv):
if rv.type.dtype.startswith("float"):
return ordered
else:
return None


@_support_point.register(OrderedRV)
def support_point_ordered(op, rv, dist):
# FIXME: This does not work with the default ordered transform
# which maps [0, 0, 0] to [0, -inf, -inf].
# return support_point(dist)
[Review comment from ricardoV94 (Member, Author), Dec 3, 2024]

@aseyboldt @lucianopaz any idea if we could easily modify the OrderedTransform to work nicely with equal values in the constrained space? The current one maps a [0, 0, 0] vector to [0, -inf, -inf]. My guess is not.

But then I don't have a nice way to define a valid support_point. I could try to add an increasing jitter, but that will sooner or later fail for bounded distributions.

return rv # Draw from the prior
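
The logp implied by this factory is the classic order-statistics density: n! times the product of the component densities, restricted to sorted values. A quick numerical cross-check of that identity using SciPy; the commented PyMC lines are hypothetical usage of this branch, not released API:

    import numpy as np
    from scipy import stats
    from scipy.special import gammaln

    # Joint density of the order statistics of n IID draws:
    #   f(x_(1), ..., x_(n)) = n! * prod_i f(x_(i)), for x sorted ascending
    n = 3
    x = np.sort(stats.norm.rvs(size=n, random_state=52))
    manual_logp = gammaln(n + 1) + stats.norm.logpdf(x).sum()  # log(n!) + sum of logpdfs

    # Hypothetical equivalent on this branch:
    #   import pymc as pm
    #   ordered_x = pm.Ordered.dist(dist=pm.Normal.dist(), shape=(n,))
    #   pm.logp(ordered_x, x).eval()  # expected to match manual_logp
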
126 changes: 101 additions & 25 deletions pymc/logprob/order.py
@@ -41,21 +41,65 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.sort import SortOp
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableOp,
_logcdf_helper,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import filter_measurable_variables
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
filter_measurable_variables,
)
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold


def _underlying_iid_rv(variable) -> TensorVariable | None:
# Check whether an IID base RV is connected to the variable through identical elemwise operations
from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.logprob.transforms import MeasurableTransform

    def iid_elemwise_root(var: TensorVariable) -> TensorVariable | None:
        node = var.owner
        if node is None:
            return None
        if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
            return var
        elif isinstance(node.op, MeasurableTransform):
            if len(node.inputs) == 1:
                return iid_elemwise_root(node.inputs[0])
            else:
                # If the non-measurable inputs are broadcasted, it is still an IID operation.
                measurable_inp = node.op.measurable_input_idx
                other_inputs = [inp for i, inp in enumerate(node.inputs) if i != measurable_inp]
                if all(all(other_inp.type.broadcastable) for other_inp in other_inputs):
                    return iid_elemwise_root(node.inputs[measurable_inp])
        return None

# Check that the root is a univariate distribution linked by only elemwise operations
latent_base_var = iid_elemwise_root(variable)

if latent_base_var is None:
return None

latent_op = latent_base_var.owner.op

if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
return None

if not all(
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
):
return None

return cast(TensorVariable, latent_base_var)


class MeasurableMax(MeasurableOp, Max):
"""A placeholder used to specify a log-likelihood for a max sub-graph."""

@@ -77,31 +77,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable]
if not filter_measurable_variables(node.inputs):
return None

# We allow Max of RandomVariables or Elemwise of univariate RandomVariables
if isinstance(base_var.owner.op, MeasurableElemwise):
latent_base_vars = [
var
for var in base_var.owner.inputs
if (var.owner and isinstance(var.owner.op, MeasurableOp))
]
if len(latent_base_vars) != 1:
return None
[latent_base_var] = latent_base_vars
else:
latent_base_var = base_var

latent_op = latent_base_var.owner.op
if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
return None
# We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
latent_base_var = _underlying_iid_rv(base_var)

# univariate i.i.d. test which also rules out other distributions
if not all(
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
):
    if latent_base_var is None:
return None

base_var = cast(TensorVariable, base_var)

if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
Expand All @@ -119,7 +144,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab


measurable_ir_rewrites_db.register(
"find_measurable_max",
find_measurable_max.__name__,
find_measurable_max,
"basic",
"max",
@@ -158,3 +183,54 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):

n = pt.prod(base_rv_shape)
return logdiffexp(n * logcdf, n * logcdf_prev)


class MeasurableSort(MeasurableOp, SortOp):
"""A placeholder used to specify a log-likelihood for a sort sub-graph."""


@_logprob.register(MeasurableSort)
def sort_logprob(op, values, base_rv, axis, **kwargs):
r"""Compute the log-likelihood graph for the `Sort` operation."""
(value,) = values

logprob = _logprob_helper(base_rv, value).sum(axis=-1)

base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
n = pt.prod(base_rv_shape, axis=-1)
sorted_logp = pt.gammaln(n + 1) + logprob

# The sorted value is not really a parameter, but we include the check in
# `CheckParameterValue` to avoid costly sorting if `check_bounds=False` in a PyMC model
return CheckParameterValue("value must be sorted", can_be_replaced_by_ninf=True)(
sorted_logp, pt.eq(value, value.sort(axis=axis, kind=op.kind)).all()
)


@node_rewriter(tracks=[SortOp])
def find_measurable_sort(fgraph, node):
if isinstance(node.op, MeasurableSort):
return None

if not filter_measurable_variables(node.inputs):
return None

[base_var, axis] = node.inputs

    # We allow Sort of RandomVariables or IID Elemwise of univariate RandomVariables
if _underlying_iid_rv(base_var) is None:
return None

# Check axis is not potentially measurable
if check_potential_measurability([axis]):
return None

return [MeasurableSort(**node.op._props_dict())(base_var, axis)]


measurable_ir_rewrites_db.register(
find_measurable_sort.__name__,
find_measurable_sort,
"basic",
"sort",
)
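
For the discrete branch of the max logp above, logdiffexp(n * logcdf, n * logcdf_prev) encodes P(max(X_1, ..., X_n) = m) = F(m)^n - F(m - 1)^n. A small simulation sketch of that identity; the Poisson distribution and constants are illustrative choices, not from the PR:

    import numpy as np
    from scipy import stats

    # P(max of n IID discrete draws equals m) = F(m)^n - F(m-1)^n
    n, mu, m = 4, 3.0, 5
    F = stats.poisson(mu).cdf
    analytic = F(m) ** n - F(m - 1) ** n

    draws = stats.poisson(mu).rvs(size=(200_000, n), random_state=0).max(axis=1)
    empirical = (draws == m).mean()
    print(analytic, empirical)  # should agree to within Monte Carlo error (~1e-3)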