From cabee44716f8d6e7482d24ecde863fe8f424641b Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 30 Jun 2025 14:26:30 +0200
Subject: [PATCH 01/11] Show error codes in run_mypy.py

---
 scripts/run_mypy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 409e255d7..032fbc938 100755
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -168,7 +168,7 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
         for section, sdf in df.reset_index().groupby(args.groupby):
             print(f"\n\n[{section}]")
             for row in sdf.itertuples():
-                print(f"{row.file}:{row.line}: {row.type}: {row.message}")
+                print(f"{row.file}:{row.line}: {row.type} [{row.errorcode}]: {row.message}")
         print()
     else:
         print(

From 9d9008e968e81ae4c98825a111aeaf47064f482d Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Sat, 28 Jun 2025 16:24:38 +0200
Subject: [PATCH 02/11] Break circular dependency between model_graph.py and
 model/core.py with specific lazy imports

---
 pymc/model/core.py  | 5 ++++-
 pymc/model_graph.py | 9 ++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/pymc/model/core.py b/pymc/model/core.py
index dc09288bc..66e633e15 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -52,7 +52,6 @@
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
-from pymc.model_graph import model_to_graphviz, model_to_mermaid
 from pymc.pytensorf import (
     PointFunc,
     SeedSequenceSeed,
@@ -440,6 +439,8 @@ def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
     def _display_(self):
         import marimo as mo
 
+        from pymc.model_graph import model_to_mermaid
+
         return mo.mermaid(model_to_mermaid(self))
 
     @staticmethod
@@ -2002,6 +2003,8 @@ def to_graphviz(
         # creates the file `schools.pdf`
         schools.to_graphviz().render("schools")
         """
+        from pymc.model_graph import model_to_graphviz
+
         return model_to_graphviz(
             model=self,
             var_names=var_names,
diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index e62e5e244..5185bfbf1 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -29,8 +29,7 @@
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.variable import TensorVariable
 
-import pymc as pm
-
+from pymc.model.core import modelcontext
 from pymc.util import VarName, get_default_varnames, get_var_name
 
 __all__ = (
@@ -662,7 +661,7 @@ def model_to_networkx(
             stacklevel=2,
         )
 
-    model = pm.modelcontext(model)
+    model = modelcontext(model)
     graph = ModelGraph(model)
     return make_networkx(
         name=model.name,
@@ -777,7 +776,7 @@ def model_to_graphviz(
             stacklevel=2,
         )
 
-    model = pm.modelcontext(model)
+    model = modelcontext(model)
     graph = ModelGraph(model)
     return make_graph(
         model.name,
@@ -910,7 +909,7 @@ def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool =
     """
 
-    model = pm.modelcontext(model)
+    model = modelcontext(model)
     graph = ModelGraph(model)
     plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
     edges = sorted(graph.edges(var_names=var_names))

From 9253343abb91a3996d501b0833dc8b130615604d Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Sun, 29 Jun 2025 23:54:18 +0200
Subject: [PATCH 03/11] Convert mu to TensorVariable in Normal

Also remove commented out code
---
 pymc/distributions/continuous.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index
ca4fece1b..fb4fd5eb3 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -487,11 +487,7 @@ class Normal(Continuous): def dist(cls, mu=0, sigma=None, tau=None, **kwargs): tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) sigma = pt.as_tensor_variable(sigma) - - # tau = pt.as_tensor_variable(tau) - # mean = median = mode = mu = pt.as_tensor_variable(floatX(mu)) - # variance = 1.0 / self.tau - + mu = pt.as_tensor_variable(mu) return super().dist([mu, sigma], **kwargs) def support_point(rv, size, mu, sigma): From 26acbe7d68829f063a266e1d391999858c6b68fe Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 17 Jun 2025 20:07:25 +0200 Subject: [PATCH 04/11] Implement xarray like semantics in `dims` module --- .github/workflows/tests.yml | 6 + .../environment-alternative-backends.yml | 2 +- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc/data.py | 24 +- pymc/dims/__init__.py | 67 ++++++ pymc/dims/distributions/__init__.py | 15 ++ pymc/dims/distributions/core.py | 191 +++++++++++++++ pymc/dims/distributions/scalar.py | 174 ++++++++++++++ pymc/dims/distributions/vector.py | 102 ++++++++ pymc/dims/math.py | 15 ++ pymc/dims/model.py | 106 +++++++++ pymc/distributions/continuous.py | 16 +- pymc/distributions/multivariate.py | 4 +- pymc/distributions/shape_utils.py | 24 +- pymc/initial_point.py | 48 ++-- pymc/logprob/basic.py | 15 +- pymc/logprob/rewriting.py | 26 ++- pymc/logprob/utils.py | 38 ++- pymc/math.py | 4 + pymc/model/transform/conditioning.py | 3 +- pymc/pytensorf.py | 112 +++++---- pymc/testing.py | 35 ++- pyproject.toml | 9 + requirements-dev.txt | 2 +- requirements.txt | 2 +- tests/dims/__init__.py | 13 ++ tests/dims/distributions/__init__.py | 13 ++ tests/dims/distributions/test_core.py | 191 +++++++++++++++ tests/dims/distributions/test_scalar.py | 217 ++++++++++++++++++ tests/dims/distributions/test_vector.py | 62 +++++ tests/dims/test_model.py | 174 ++++++++++++++ tests/dims/utils.py | 64 ++++++ tests/distributions/test_continuous.py | 3 +- tests/distributions/test_shape_utils.py | 9 + 38 files changed, 1653 insertions(+), 143 deletions(-) create mode 100644 pymc/dims/__init__.py create mode 100644 pymc/dims/distributions/__init__.py create mode 100644 pymc/dims/distributions/core.py create mode 100644 pymc/dims/distributions/scalar.py create mode 100644 pymc/dims/distributions/vector.py create mode 100644 pymc/dims/math.py create mode 100644 pymc/dims/model.py create mode 100644 tests/dims/__init__.py create mode 100644 tests/dims/distributions/__init__.py create mode 100644 tests/dims/distributions/test_core.py create mode 100644 tests/dims/distributions/test_scalar.py create mode 100644 tests/dims/distributions/test_vector.py create mode 100644 tests/dims/test_model.py create mode 100644 tests/dims/utils.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d9b4b000c..4ef5add46 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -135,6 +135,12 @@ jobs: tests/logprob/test_transforms.py tests/logprob/test_utils.py + - | + tests/dims/distributions/test_core.py + tests/dims/distributions/test_scalar.py + tests/dims/distributions/test_vector.py + tests/dims/test_model.py + fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/conda-envs/environment-alternative-backends.yml b/conda-envs/environment-alternative-backends.yml index 
5030e7bac..98772e118 100644 --- a/conda-envs/environment-alternative-backends.yml +++ b/conda-envs/environment-alternative-backends.yml @@ -22,7 +22,7 @@ dependencies: - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.5,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index e161b66e1..a7be421ad 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.5,<2.32 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 67e6673af..865605f20 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -11,7 +11,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.5,<2.32 - python-graphviz - rich>=13.7.1 - scipy>=1.4.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 2230d08e7..d8d31ef93 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -14,7 +14,7 @@ dependencies: - pandas>=0.24.0 - pip - polyagamma -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.5,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index daf039814..5a9641729 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.5,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 7c9c28b70..3716c6f76 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -15,7 +15,7 @@ dependencies: - pandas>=0.24.0 - pip - polyagamma -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.5,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/pymc/data.py b/pymc/data.py index 507f547e5..cfade3791 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -13,11 +13,12 @@ # limitations under the License. import io +import typing import urllib.request from collections.abc import Sequence from copy import copy -from typing import cast +from typing import Union, cast import numpy as np import pandas as pd @@ -32,12 +33,13 @@ from pytensor.tensor.random.basic import IntegersRV from pytensor.tensor.variable import TensorConstant, TensorVariable -import pymc as pm - -from pymc.logprob.utils import rvs_in_graph -from pymc.pytensorf import convert_data +from pymc.exceptions import ShapeError +from pymc.pytensorf import convert_data, rvs_in_graph from pymc.vartypes import isgenerator +if typing.TYPE_CHECKING: + from pymc.model.core import Model + __all__ = [ "Data", "Minibatch", @@ -197,7 +199,7 @@ def determine_coords( if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: - raise pm.exceptions.ShapeError( + raise ShapeError( "Invalid data shape. 
The rank of the dataset must match the length of `dims`.", actual=value.shape, expected=value.ndim, @@ -222,6 +224,7 @@ def Data( dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, infer_dims_and_coords=False, + model: Union["Model", None] = None, **kwargs, ) -> SharedVariable | TensorConstant: """Create a data container that registers a data variable with the model. @@ -286,6 +289,8 @@ def Data( ... model.set_data("data", data_vals) ... idatas.append(pm.sample()) """ + from pymc.model.core import modelcontext + if coords is None: coords = {} @@ -293,8 +298,9 @@ def Data( value = np.array(value) # Add data container to the named variables of the model. - model = pm.Model.get_context(error_if_none=False) - if model is None: + try: + model = modelcontext(model) + except TypeError: raise TypeError( "No model on context stack, which is needed to instantiate a data container. " "Add variable inside a 'with model:' block." @@ -321,7 +327,7 @@ def Data( if isinstance(dims, str): dims = (dims,) if not (dims is None or len(dims) == x.ndim): - raise pm.exceptions.ShapeError( + raise ShapeError( "Length of `dims` must match the dimensions of the dataset.", actual=len(dims), expected=x.ndim, diff --git a/pymc/dims/__init__.py b/pymc/dims/__init__.py new file mode 100644 index 000000000..aa9a3c983 --- /dev/null +++ b/pymc/dims/__init__.py @@ -0,0 +1,67 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def __init__(): + """Make PyMC aware of the xtensor functionality. + + This should be done eagerly once development matures. 
+    """
+    import datetime
+    import warnings
+
+    from pytensor.compile import optdb
+
+    from pymc.initial_point import initial_point_rewrites_db
+    from pymc.logprob.abstract import MeasurableOp
+    from pymc.logprob.rewriting import logprob_rewrites_db
+
+    # Filter PyTensor xtensor warning, we emit our own warning
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+        import pytensor.xtensor
+
+        from pytensor.xtensor.vectorization import XRV
+
+    # Make PyMC aware of xtensor functionality
+    MeasurableOp.register(XRV)
+    lower_xtensor_query = optdb.query("+lower_xtensor")
+    logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
+    initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
+
+    # TODO: Better model of probability of bugs
+    day_of_conception = datetime.date(2025, 6, 17)
+    day_of_last_bug = datetime.date(2025, 6, 30)
+    today = datetime.date.today()
+    days_with_bugs = (day_of_last_bug - day_of_conception).days
+    days_without_bugs = (today - day_of_last_bug).days
+    p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
+    if p > 0.05:
+        warnings.warn(
+            f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
+            "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n"
+            "Disclaimer: This is an experimental API and may change at any time.",
+            UserWarning,
+            stacklevel=2,
+        )


+__init__()
+del __init__
+
+from pytensor.xtensor import as_xtensor, concat
+
+from pymc.dims import math
+from pymc.dims.distributions import *
+from pymc.dims.model import Data, Deterministic, Potential, with_dims
diff --git a/pymc/dims/distributions/__init__.py b/pymc/dims/distributions/__init__.py
new file mode 100644
index 000000000..da85bc146
--- /dev/null
+++ b/pymc/dims/distributions/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pymc.dims.distributions.scalar import *
+from pymc.dims.distributions.vector import *
diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py
new file mode 100644
index 000000000..bd48db0ec
--- /dev/null
+++ b/pymc/dims/distributions/core.py
@@ -0,0 +1,191 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
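+# A minimal usage sketch of the machinery defined below (assuming
+# `import pymc as pm` and `from pymc import dims as pmx`, as in the tests):
+#
+#   with pm.Model(coords={"city": range(3)}):
+#       sigma = pmx.HalfNormal("sigma")
+#       x = pmx.Normal("x", mu=0, sigma=sigma, dims=("city",))
+#
+# `x` is an XTensorVariable with `x.type.dims == ("city",)`. Parameters must
+# either be dimensionless or carry their own dims (see `_as_xtensor` below).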
+from collections.abc import Callable, Sequence
+from itertools import chain
+
+from pytensor.graph.basic import Variable
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.type import XTensorVariable
+
+from pymc import modelcontext
+from pymc.dims.model import with_dims
+from pymc.distributions import transforms
+from pymc.distributions.distribution import _support_point, support_point
+from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis
+from pymc.logprob.transforms import Transform
+from pymc.util import UNSET
+
+
+@_support_point.register(DimShuffle)
+def dimshuffle_support_point(ds_op, _, rv):
+    # We implement support point for DimShuffle because
+    # DimDistribution can register a transposed version of a variable.
+
+    return ds_op(support_point(rv))
+
+
+class DimDistribution:
+    """Base class for PyMC distributions that wrap pytensor.xtensor.random operations and follow xarray-like semantics."""
+
+    xrv_op: Callable
+    default_transform: Transform | None = None
+
+    @staticmethod
+    def _as_xtensor(x):
+        try:
+            return as_xtensor(x)
+        except TypeError:
+            try:
+                return with_dims(x)
+            except ValueError:
+                raise ValueError(
+                    f"Variable {x} must have dims associated with it.\n"
+                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
+                    "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
+                )
+
+    def __new__(
+        cls,
+        name: str,
+        *dist_params,
+        dims: DimsWithEllipsis | None = None,
+        initval=None,
+        observed=None,
+        total_size=None,
+        transform=UNSET,
+        default_transform=UNSET,
+        model=None,
+        **kwargs,
+    ):
+        try:
+            model = modelcontext(model)
+        except TypeError:
+            raise TypeError(
+                "No model on context stack, which is needed to instantiate distributions. "
+                "Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
+            )
+
+        if not isinstance(name, str):
+            raise TypeError(f"Name needs to be a string but got: {name}")
+
+        dims = convert_dims_with_ellipsis(dims)
+        if dims is None:
+            dim_lengths = {}
+        else:
+            try:
+                dim_lengths = {dim: model.dim_lengths[dim] for dim in dims if dim is not Ellipsis}
+            except KeyError:
+                raise ValueError(
+                    f"Not all dims {dims} are part of the model coords. "
+                    f"Add them at initialization time or use `model.add_coord` before defining the distribution."
+                )
+
+        if observed is not None:
+            observed = cls._as_xtensor(observed)
+
+            # Propagate observed dims to dim_lengths
+            for observed_dim in observed.type.dims:
+                if observed_dim not in dim_lengths:
+                    dim_lengths[observed_dim] = model.dim_lengths[observed_dim]
+
+        rv = cls.dist(*dist_params, dim_lengths=dim_lengths, **kwargs)
+
+        # User provided dims must specify all dims or use ellipsis
+        if dims is not None:
+            if (... not in dims) and (set(dims) != set(rv.type.dims)):
+                raise ValueError(
+                    f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
+                    "Use ellipsis to specify all other dimensions."
+ ) + # Use provided dims to transpose the output to the desired order + rv = rv.transpose(*dims) + + rv_dims = rv.type.dims + if observed is None: + if default_transform is UNSET: + default_transform = cls.default_transform + else: + # Align observed dims with those of the RV + # TODO: If this fails give a more informative error message + observed = observed.transpose(*rv_dims).values + + rv = model.register_rv( + rv.values, + name=name, + observed=observed, + total_size=total_size, + dims=rv_dims, + transform=transform, + default_transform=default_transform, + initval=initval, + ) + + return as_xtensor(rv, dims=rv_dims) + + @classmethod + def dist( + cls, + dist_params, + *, + dim_lengths: dict[str, Variable | int] | None = None, + core_dims: str | Sequence[str] | None = None, + **kwargs, + ) -> XTensorVariable: + for invalid_kwarg in ("size", "shape", "dims"): + if invalid_kwarg in kwargs: + raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.") + + # XRV requires only extra_dims, not dims + dist_params = [cls._as_xtensor(param) for param in dist_params] + + if dim_lengths is None: + extra_dims = None + else: + # Exclude dims that are implied by the parameters or core_dims + implied_dims = set(chain.from_iterable(param.type.dims for param in dist_params)) + if core_dims is not None: + if isinstance(core_dims, str): + implied_dims.add(core_dims) + else: + implied_dims.update(core_dims) + + extra_dims = { + dim: length for dim, length in dim_lengths.items() if dim not in implied_dims + } + return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs) + + +class VectorDimDistribution(DimDistribution): + @classmethod + def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs): + # Add a helpful error message if core_dims is not provided + if core_dims is None: + raise ValueError( + f"{self.__name__} requires core_dims to be specified, as it involves non-scalar inputs or outputs." + "Check the documentation of the distribution for details." + ) + return super().dist(*args, core_dims=core_dims, **kwargs) + + +class PositiveDimDistribution(DimDistribution): + """Base class for positive continuous distributions.""" + + default_transform = transforms.log + + +class UnitDimDistribution(DimDistribution): + """Base class for unit-valued distributions.""" + + default_transform = transforms.logodds diff --git a/pymc/dims/distributions/scalar.py b/pymc/dims/distributions/scalar.py new file mode 100644 index 000000000..540f69cf8 --- /dev/null +++ b/pymc/dims/distributions/scalar.py @@ -0,0 +1,174 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytensor.xtensor as ptx +import pytensor.xtensor.random as pxr + +from pymc.dims.distributions.core import ( + DimDistribution, + PositiveDimDistribution, + UnitDimDistribution, +) +from pymc.distributions.continuous import Beta as RegularBeta +from pymc.distributions.continuous import Gamma as RegularGamma +from pymc.distributions.continuous import flat, halfflat + + +def _get_sigma_from_either_sigma_or_tau(*, sigma, tau): + if sigma is not None and tau is not None: + raise ValueError("Can't pass both tau and sigma") + + if sigma is None and tau is None: + return 1.0 + + if sigma is not None: + return sigma + + return ptx.math.reciprocal(ptx.math.sqrt(tau)) + + +class Flat(DimDistribution): + xrv_op = pxr._as_xrv(flat) + + @classmethod + def dist(cls, **kwargs): + return super().dist([], **kwargs) + + +class HalfFlat(PositiveDimDistribution): + xrv_op = pxr._as_xrv(halfflat, [], ()) + + @classmethod + def dist(cls, **kwargs): + return super().dist([], **kwargs) + + +class Normal(DimDistribution): + xrv_op = pxr.normal + + @classmethod + def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) + return super().dist([mu, sigma], **kwargs) + + +class HalfNormal(PositiveDimDistribution): + xrv_op = pxr.halfnormal + + @classmethod + def dist(cls, sigma=None, *, tau=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) + return super().dist([0.0, sigma], **kwargs) + + +class LogNormal(PositiveDimDistribution): + xrv_op = pxr.lognormal + + @classmethod + def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) + return super().dist([mu, sigma], **kwargs) + + +class StudentT(DimDistribution): + xrv_op = pxr.t + + @classmethod + def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam) + return super().dist([nu, mu, sigma], **kwargs) + + +class Cauchy(DimDistribution): + xrv_op = pxr.cauchy + + @classmethod + def dist(cls, alpha, beta, **kwargs): + return super().dist([alpha, beta], **kwargs) + + +class HalfCauchy(PositiveDimDistribution): + xrv_op = pxr.halfcauchy + + @classmethod + def dist(cls, beta, **kwargs): + return super().dist([0.0, beta], **kwargs) + + +class Beta(UnitDimDistribution): + xrv_op = pxr.beta + + @classmethod + def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs): + alpha, beta = RegularBeta.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma, nu=nu) + return super().dist([alpha, beta], **kwargs) + + +class Laplace(DimDistribution): + xrv_op = pxr.laplace + + @classmethod + def dist(cls, mu=0, b=1, **kwargs): + return super().dist([mu, b], **kwargs) + + +class Exponential(PositiveDimDistribution): + xrv_op = pxr.exponential + + @classmethod + def dist(cls, lam=None, *, scale=None, **kwargs): + if lam is None and scale is None: + scale = 1.0 + elif lam is not None and scale is not None: + raise ValueError("Cannot pass both 'lam' and 'scale'. 
Use one of them.")
+        elif lam is not None:
+            scale = 1 / lam
+        return super().dist([scale], **kwargs)
+
+
+class Gamma(PositiveDimDistribution):
+    xrv_op = pxr.gamma
+
+    @classmethod
+    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
+        if (alpha is not None) and (beta is not None):
+            pass
+        elif (mu is not None) and (sigma is not None):
+            # Use sign of sigma to not let negative sigma fly by
+            alpha = (mu**2 / sigma**2) * ptx.math.sign(sigma)
+            beta = mu / sigma**2
+        else:
+            raise ValueError(
+                "Incompatible parameterization. Either use alpha and beta, or mu and sigma."
+            )
+        return super().dist([alpha, ptx.math.reciprocal(beta)], **kwargs)
+
+
+class InverseGamma(PositiveDimDistribution):
+    xrv_op = pxr.invgamma
+
+    @classmethod
+    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
+        if alpha is not None:
+            if beta is None:
+                beta = 1.0
+        elif (mu is not None) and (sigma is not None):
+            # Use sign of sigma to not let negative sigma fly by
+            alpha = ((2 * sigma**2 + mu**2) / sigma**2) * ptx.math.sign(sigma)
+            beta = mu * (mu**2 + sigma**2) / sigma**2
+        else:
+            raise ValueError(
+                "Incompatible parameterization. Either use alpha and (optionally) beta, or mu and sigma"
+            )
+        return super().dist([alpha, beta], **kwargs)
diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py
new file mode 100644
index 000000000..b11bd56f2
--- /dev/null
+++ b/pymc/dims/distributions/vector.py
@@ -0,0 +1,102 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytensor.xtensor as ptx
+import pytensor.xtensor.random as ptxr
+
+from pytensor.tensor.random.utils import normalize_size_param
+from pytensor.xtensor import random as pxr
+
+from pymc.dims.distributions.core import VectorDimDistribution
+from pymc.distributions.multivariate import ZeroSumNormalRV
+
+
+class Categorical(VectorDimDistribution):
+    xrv_op = ptxr.categorical
+
+    @classmethod
+    def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs):
+        if p is not None and logit_p is not None:
+            raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
+        elif p is None and logit_p is None:
+            raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")
+
+        if logit_p is not None:
+            p = ptx.math.softmax(logit_p, dim=core_dims)
+        return super().dist([p], core_dims=core_dims, **kwargs)
+
+
+class MvNormal(VectorDimDistribution):
+    """Multivariate Normal distribution.
+
+    Parameters
+    ----------
+    mu : xtensor_like
+        Mean vector of the distribution.
+    cov : xtensor_like, optional
+        Covariance matrix of the distribution. Exactly one of `cov` or `chol` must be provided.
+    chol : xtensor_like, optional
+        Cholesky decomposition of the covariance matrix. Exactly one of `cov` or `chol` must be provided.
+    lower : bool, default True
+        If True, the Cholesky decomposition is assumed to be lower triangular.
+ If False, it is assumed to be upper triangular. + core_dims: Sequence of string, optional + Sequence of two strings representing the core dimensions of the distribution. + The two dimensions must be present in `cov` or `chol`, and exactly one must also be present in `mu`. + **kwargs + Additional keyword arguments used to define the distribution. + + Returns + ------- + XTensorVariable + An xtensor variable representing the multivariate normal distribution. + The output contains the core dimension that is shared between `mu` and `cov` or `chol`. + + """ + + xrv_op = pxr.multivariate_normal + + @classmethod + def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs): + if "tau" in kwargs: + raise NotImplementedError("MvNormal does not support 'tau' parameter.") + + if not (isinstance(core_dims, tuple | list) and len(core_dims) == 2): + raise ValueError("MvNormal requires 2 core_dims") + + if cov is None and chol is None: + raise ValueError("Either 'cov' or 'chol' must be provided.") + + if chol is not None: + d0, d1 = core_dims + if not lower: + # By logical symmetry this must be the only correct way to implement lower + # We refuse to test it because it is not useful + d1, d0 = d0, d1 + + chol = cls._as_xtensor(chol) + # chol @ chol.T in xarray semantics requires a rename + safe_name = "_" + if "_" in chol.type.dims: + safe_name *= max(map(len, chol.type.dims)) + 1 + cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1}) + + return super().dist([mu, cov], core_dims=core_dims, **kwargs) + + +class DimZeroSumNormalRV(ZeroSumNormalRV): + def make_node(self, rng, size, sigma, support_shape): + if not self.input_types[1].in_same_class(normalize_size_param(size).type): + # We need to rebuild the graph with new size type + return self.rv_op(sigma, support_shape, size=size, rng=rng).owner + return super().make_node(rng, size, sigma, support_shape) diff --git a/pymc/dims/math.py b/pymc/dims/math.py new file mode 100644 index 000000000..5e1c8ba17 --- /dev/null +++ b/pymc/dims/math.py @@ -0,0 +1,15 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytensor.xtensor import linalg +from pytensor.xtensor.math import * diff --git a/pymc/dims/model.py b/pymc/dims/model.py new file mode 100644 index 000000000..e76bf06ec --- /dev/null +++ b/pymc/dims/model.py @@ -0,0 +1,106 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
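+# A minimal sketch of the wrappers defined in this module (assuming an active
+# model with an "obs" dim of length 3; the variable names are illustrative only):
+#
+#   x = Data("x", np.zeros(3), dims=("obs",))  # registered and returned with dims ("obs",)
+#   y = Deterministic("y", x + 1)              # dims propagate through xtensor ops
+#
+# `with_dims` is the inverse helper: it recovers an XTensorVariable (with dims)
+# from a TensorVariable that was previously registered in the model.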
+from collections.abc import Callable
+
+from pytensor.tensor import TensorVariable
+from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.basic import TensorFromXTensor
+from pytensor.xtensor.type import XTensorVariable
+
+from pymc.data import Data as RegularData
+from pymc.distributions.shape_utils import (
+    Dims,
+    DimsWithEllipsis,
+    convert_dims,
+    convert_dims_with_ellipsis,
+)
+from pymc.model.core import Deterministic as RegularDeterministic
+from pymc.model.core import Model, modelcontext
+from pymc.model.core import Potential as RegularPotential
+
+
+def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable:
+    """Recover the dims of a variable that was registered in the Model."""
+    if isinstance(x, XTensorVariable):
+        return x
+
+    if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor):
+        dims = x.owner.inputs[0].type.dims
+        return as_xtensor(x, dims=dims, name=x.name)
+
+    # Try accessing the model context to get dims
+    try:
+        model = modelcontext(model)
+        if (
+            model.named_vars.get(x.name, None) is x
+            and (dims := model.named_vars_to_dims.get(x.name, None)) is not None
+        ):
+            return as_xtensor(x, dims=dims, name=x.name)
+    except TypeError:
+        pass
+
+    raise ValueError(f"Variable {x} doesn't have dims associated with it")
+
+
+def Data(
+    name: str, value, dims: Dims = None, model: Model | None = None, **kwargs
+) -> XTensorVariable:
+    """Wrapper around pymc.Data that returns an XTensorVariable."""
+    model = modelcontext(model)
+    dims = convert_dims(dims)
+
+    with model:
+        value = RegularData(name, value, dims=dims, **kwargs)
+
+    dims = model.named_vars_to_dims.get(value.name, None)
+    if dims is None and value.ndim > 0:
+        raise ValueError("pymc.dims.Data requires dims to be specified for non-scalar data.")
+
+    return as_xtensor(value, dims=dims, name=name)  # type: ignore[arg-type]
+
+
+def _register_and_return_xtensor_variable(
+    name: str,
+    value: TensorVariable | XTensorVariable,
+    dims: DimsWithEllipsis | None,
+    model: Model | None,
+    registration_func: Callable,
+) -> XTensorVariable:
+    if isinstance(value, XTensorVariable):
+        dims = convert_dims_with_ellipsis(dims)
+        if dims is not None:
+            # If dims are provided, apply a transpose to align with the user expectation
+            value = value.transpose(*dims)
+        # Regardless of whether dims are provided, we now have them
+        dims = value.type.dims
+        # Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly.
+        value = value.values  # type: ignore[union-attr]
+
+    value = registration_func(name, value, dims=dims, model=model)
+
+    return as_xtensor(value, dims=dims, name=name)  # type: ignore[arg-type]
+
+
+def Deterministic(
+    name: str, value, dims: DimsWithEllipsis | None = None, model: Model | None = None
+) -> XTensorVariable:
+    """Wrapper around pymc.Deterministic that returns an XTensorVariable."""
+    return _register_and_return_xtensor_variable(name, value, dims, model, RegularDeterministic)
+
+
+def Potential(
+    name: str, value, dims: DimsWithEllipsis | None = None, model: Model | None = None
+) -> XTensorVariable:
+    """Wrapper around pymc.Potential that returns an XTensorVariable."""
+    return _register_and_return_xtensor_variable(name, value, dims, model, RegularPotential)
diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index fb4fd5eb3..1ccc8c06a 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -2370,11 +2370,8 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
         if (alpha is not None) and (beta is not None):
             pass
         elif (mu is not None) and (sigma is not None):
-            if isinstance(sigma, Variable):
-                sigma = check_parameters(sigma, sigma > 0, msg="sigma > 0")
-            else:
-                assert np.all(np.asarray(sigma) > 0)
-            alpha = mu**2 / sigma**2
+            # Use sign of sigma to not let negative sigma fly by
+            alpha = (mu**2 / sigma**2) * pt.sign(sigma)
             beta = mu / sigma**2
         else:
             raise ValueError(
@@ -2496,13 +2493,10 @@ def _get_alpha_beta(cls, alpha, beta, mu, sigma):
             if beta is not None:
                 pass
             else:
-                beta = 1
+                beta = 1.0
         elif (mu is not None) and (sigma is not None):
-            if isinstance(sigma, Variable):
-                sigma = check_parameters(sigma, sigma > 0, msg="sigma > 0")
-            else:
-                assert np.all(np.asarray(sigma) > 0)
-            alpha = (2 * sigma**2 + mu**2) / sigma**2
+            # Use sign of sigma to not let negative sigma fly by
+            alpha = ((2 * sigma**2 + mu**2) / sigma**2) * pt.sign(sigma)
             beta = mu * (mu**2 + sigma**2) / sigma**2
         else:
             raise ValueError(
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index ac2a13431..f5a5506bb 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -149,11 +149,11 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
             raise ValueError("chol must be at least two dimensional.")
 
         if not lower:
-            chol = pt.swapaxes(chol, -1, -2)
+            chol = chol.mT
 
         # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
         chol.tag.lower_triangular = True
-        cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))
+        cov = chol @ chol.mT
 
     return cov
 
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index 8ef378b76..85177369a 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -18,6 +18,7 @@
 
 from collections.abc import Sequence
 from functools import singledispatch
+from types import EllipsisType
 from typing import Any, TypeAlias, cast
 
 import numpy as np
@@ -87,11 +88,13 @@ def _check_shape_type(shape):
 # User-provided can be lazily specified as scalars
 Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
 Dims: TypeAlias = str | Sequence[str | None]
+DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
 Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]
 
 # After conversion to vectors
 StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...]
-StrongDims: TypeAlias = Sequence[str | None]
+StrongDims: TypeAlias = Sequence[str]
+StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
 StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
 
 
@@ -107,7 +110,24 @@ def convert_dims(dims: Dims | None) -> StrongDims | None:
     else:
         raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")
 
-    return dims
+    return dims  # type: ignore[return-value]
+
+
+def convert_dims_with_ellipsis(dims: DimsWithEllipsis | None) -> StrongDimsWithEllipsis | None:
+    """Process a user-provided dims variable into None or a valid dims tuple with ellipsis."""
+    if dims is None:
+        return None
+
+    if isinstance(dims, str | EllipsisType):
+        dims = (dims,)
+    elif isinstance(dims, list | tuple):
+        dims = tuple(dims)
+    else:
+        raise ValueError(
+            f"The `dims` parameter must be a tuple, list, str or Ellipsis. Actual: {type(dims)}"
+        )
+
+    return dims  # type: ignore[return-value]
 
 
 def convert_shape(shape: Shape) -> StrongShape | None:
diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index df9419c74..4704979e3 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -20,8 +20,9 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Variable
+from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
 from pytensor.tensor.variable import TensorVariable
 
 from pymc.logprob.transforms import Transform
@@ -37,6 +38,8 @@
 StartDict = dict[Variable | str, np.ndarray | Variable | str]
 PointType = dict[str, np.ndarray]
 
+initial_point_rewrites_db = SequenceDB()
+initial_point_basic_query = RewriteDatabaseQuery(include=["basic"])
 
 def convert_str_to_rv_dict(
@@ -230,11 +233,20 @@ def make_initial_point_expression(
     if jitter_rvs is None:
         jitter_rvs = set()
 
+    # Clone free_rvs so we don't modify the original graph
+    initial_point_fgraph = FunctionGraph(outputs=free_rvs, clone=True)
+
+    # Apply any rewrites necessary to compute the initial points.
+    initial_point_rewriter = initial_point_rewrites_db.query(initial_point_basic_query)
+    if initial_point_rewriter:
+        initial_point_rewriter.rewrite(initial_point_fgraph)
+
+    free_rvs_clone = initial_point_fgraph.outputs
+
     initial_values = []
     initial_values_transformed = []
-
-    for variable in free_rvs:
-        strategy = initval_strategies.get(variable, None)
+    for old_variable, variable in zip(free_rvs, free_rvs_clone):
+        strategy = initval_strategies.get(old_variable)
         if strategy is None:
             strategy = default_strategy
 
@@ -261,9 +273,13 @@
                 f'Invalid string strategy: {strategy}. It must be one of ["support_point", "prior"]'
             )
         else:
-            value = pt.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)
+            if isinstance(strategy, Variable) and not isinstance(strategy, Constant):
+                raise ValueError(
+                    f"Initial value must be a constant variable. {strategy} comes from an Apply node {strategy.owner}"
+                )
+            value = pt.as_tensor(strategy, variable.dtype).astype(variable.dtype)
 
-        transform = rvs_to_transforms.get(variable, None)
+        transform = rvs_to_transforms.get(old_variable, None)
 
         if transform is not None:
             value = transform.forward(value, *variable.owner.inputs)
@@ -281,28 +297,16 @@
 
         initial_values.append(value)
 
-    all_outputs: list[TensorVariable] = []
-    all_outputs.extend(free_rvs)
-    all_outputs.extend(initial_values)
-    all_outputs.extend(initial_values_transformed)
-
-    copy_graph = FunctionGraph(outputs=all_outputs, clone=True)
-
-    n_variables = len(free_rvs)
-    free_rvs_clone = copy_graph.outputs[:n_variables]
-    initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
-    initial_values_transformed_clone = copy_graph.outputs[-n_variables:]
-
     # We now replace all rvs by the respective initial_point expressions
     # in the constrained (untransformed) space. We do this in reverse topological
     # order, so that later nodes do not reintroduce expressions with earlier
     # rvs that would need to once again be replaced by their initial_points
-    graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
-    toposort_replace(graph, tuple(zip(free_rvs_clone, initial_values_clone)), reverse=True)
+    toposort_replace(initial_point_fgraph, tuple(zip(free_rvs_clone, initial_values)), reverse=True)
 
     if not return_transformed:
-        return graph.outputs
+        return initial_point_fgraph.outputs
+
     # Because the unconstrained (transformed) expressions are a subgraph of the
     # constrained initial point they were also automatically updated inplace
     # when calling graph.replace_all above, so we don't need to do anything else
-    return initial_values_transformed_clone
+    return initial_values_transformed
diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py
index 85d74aab3..9fb76df16 100644
--- a/pymc/logprob/basic.py
+++ b/pymc/logprob/basic.py
@@ -46,6 +46,7 @@
     Constant,
     Variable,
     ancestors,
+    walk,
 )
 from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
 from pytensor.tensor.variable import TensorVariable
@@ -60,8 +61,8 @@
 from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
 from pymc.logprob.transform_value import TransformValuesRewrite
 from pymc.logprob.transforms import Transform
-from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph
-from pymc.pytensorf import replace_vars_in_graphs
+from pymc.logprob.utils import get_related_valued_nodes
+from pymc.pytensorf import expand_inner_graph, replace_vars_in_graphs
 
 TensorLike: TypeAlias = Variable | float | np.ndarray
 
@@ -71,9 +72,13 @@ def _find_unallowed_rvs_in_graph(graph):
     from pymc.distributions.simulator import SimulatorRV
 
     return {
-        rv
-        for rv in rvs_in_graph(graph)
-        if not isinstance(rv.owner.op, SimulatorRV | MinibatchIndexRV)
+        var
+        for var in walk(graph, expand_inner_graph, False)
+        if (
+            var.owner
+            and isinstance(var.owner.op, MeasurableOp)
+            and not isinstance(var.owner.op, SimulatorRV | MinibatchIndexRV)
+        )
     }
 
 
diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index b5a6b23a0..ea0202f00 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -132,6 +132,7 @@ def remove_DiracDelta(fgraph, node):
     return [dd_val]
 
 
+logprob_rewrites_basic_query = RewriteDatabaseQuery(include=["basic"])
 logprob_rewrites_db = SequenceDB()
 logprob_rewrites_db.name = "logprob_rewrites_db"
 
@@ -146,16 +147,21 @@
         failure_callback=None,
     ),
     "basic",
+    position=0,
 )
 
 #
Introduce sigmoid. We do it before canonicalization so that useless mul are removed next logprob_rewrites_db.register( - "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic" + "local_exp_over_1_plus_exp", + out2in(local_exp_over_1_plus_exp), + "basic", + position=0.9, ) logprob_rewrites_db.register( "pre-canonicalize", optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"), "basic", + position=1, ) # These rewrites convert un-measurable variables into their measurable forms, @@ -164,18 +170,26 @@ def remove_DiracDelta(fgraph, node): measurable_ir_rewrites_db = EquilibriumDB() measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db" -logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic") +logprob_rewrites_db.register( + "measurable_ir_rewrites", + measurable_ir_rewrites_db, + "basic", + position=2, +) # These rewrites push random/measurable variables "down", making them closer to # (or eventually) the graph outputs. Often this is done by lifting other `Op`s # "up" through the random/measurable variables and into their inputs. measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic") -# These rewrites are used to introduce specalized operations with better logprob graphs +# These rewrites are used to introduce specialized operations with better logprob graphs specialization_ir_rewrites_db = EquilibriumDB() specialization_ir_rewrites_db.name = "specialization_ir_rewrites_db" logprob_rewrites_db.register( - "specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic" + "specialization_ir_rewrites_db", + specialization_ir_rewrites_db, + "basic", + position=3, ) @@ -183,6 +197,7 @@ def remove_DiracDelta(fgraph, node): "post-canonicalize", optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"), "basic", + position=4, ) # Rewrites that remove IR Ops @@ -192,6 +207,7 @@ def remove_DiracDelta(fgraph, node): "cleanup_ir_rewrites", TopoDB(cleanup_ir_rewrites_db, order="out_to_in", ignore_newtrees=True, failure_callback=None), "cleanup", + position=5, ) cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup") @@ -250,7 +266,7 @@ def construct_ir_fgraph( toposort_replace(fgraph, replacements, reverse=True) if ir_rewriter is None: - ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])) + ir_rewriter = logprob_rewrites_db.query(logprob_rewrites_basic_query) ir_rewriter.rewrite(fgraph) # Reintroduce original value variables diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index a02834103..0f521393d 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -43,21 +43,18 @@ from pytensor import tensor as pt from pytensor.graph import Apply, Op, node_rewriter -from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk +from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import HasInnerGraph from pytensor.link.c.type import CType from pytensor.raise_op import CheckAndRaise from pytensor.scalar.basic import Mul from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import MeasurableOp, ValuedRV, _logprob from pymc.pytensorf import 
replace_vars_in_graphs
-from pymc.util import makeiter
 
 if typing.TYPE_CHECKING:
     from pymc.logprob.transforms import Transform
@@ -130,26 +127,6 @@ def populate_replacements(var):
     return replace_vars_in_graphs(graphs, replacements)
 
 
-def rvs_in_graph(vars: Variable | Sequence[Variable]) -> set[Variable]:
-    """Assert that there are no `MeasurableOp` nodes in a graph."""
-
-    def expand(r):
-        owner = r.owner
-        if owner:
-            inputs = list(reversed(owner.inputs))
-
-            if isinstance(owner.op, HasInnerGraph):
-                inputs += owner.op.inner_outputs
-
-            return inputs
-
-    return {
-        node
-        for node in walk(makeiter(vars), expand, False)
-        if node.owner and isinstance(node.owner.op, RandomVariable | MeasurableOp)
-    }
-
-
 def convert_indices(indices, entry):
     if indices and isinstance(entry, CType):
         rval = indices.pop(0)
@@ -334,3 +311,16 @@ def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]:
         for client, _ in clients[out]
         if isinstance(client.op, ValuedRV)
     ]
+
+
+def __getattr__(name):
+    if name == "rvs_in_graph":
+        warnings.warn(
+            f"{name} has been moved to `pymc.pytensorf`. Importing from `pymc.logprob.utils` will fail in a future release.",
+            FutureWarning,
+        )
+        from pymc.pytensorf import rvs_in_graph
+
+        return rvs_in_graph
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/pymc/math.py b/pymc/math.py
index 13655f534..65ddacfb9 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -33,6 +33,7 @@
     arcsinh,
     arctan,
     arctanh,
+    as_tensor,
     broadcast_to,
     ceil,
     clip,
@@ -42,6 +43,7 @@
     cosh,
     cumprod,
     cumsum,
+    diff,
     dot,
     eq,
     erf,
@@ -103,6 +105,7 @@
     "arcsinh",
     "arctan",
     "arctanh",
+    "as_tensor",
     "batched_diag",
     "block_diagonal",
     "broadcast_to",
@@ -115,6 +118,7 @@
     "cosh",
     "cumprod",
     "cumsum",
+    "diff",
     "dot",
     "eq",
     "erf",
diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py
index edcf5862b..6c40ab563 100644
--- a/pymc/model/transform/conditioning.py
+++ b/pymc/model/transform/conditioning.py
@@ -22,7 +22,6 @@
 from pytensor.tensor import TensorVariable
 
 from pymc.logprob.transforms import Transform
-from pymc.logprob.utils import rvs_in_graph
 from pymc.model.core import Model
 from pymc.model.fgraph import (
     ModelDeterministic,
@@ -41,7 +40,7 @@
     parse_vars,
     prune_vars_detached_from_observed,
 )
-from pymc.pytensorf import replace_vars_in_graphs, toposort_replace
+from pymc.pytensorf import replace_vars_in_graphs, rvs_in_graph, toposort_replace
 from pymc.util import get_transformed_name, get_untransformed_name
 
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index f1d69c928..3ce308147 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -33,13 +33,15 @@
     clone_get_equiv,
     equal_computations,
     graph_inputs,
+    walk,
 )
 from pytensor.graph.fg import FunctionGraph, Output
+from pytensor.graph.op import HasInnerGraph
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.op import RandomVariable, RNGConsumerOp
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
 from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
@@ -133,6 +135,9 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVar
     return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs)
 
 
+_cheap_eval_mode = Mode(linker="py", optimizer="minimum_compile")
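+# Note: the Python linker plus the "minimum_compile" query applies only the
+# rewrites strictly required for compilation, so this mode is cheap to build
+# and suits one-off evaluations of small constant graphs (e.g. in
+# `extract_obs_data` below) where compilation time would dominate.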
+
+
 def extract_obs_data(x: TensorVariable) -> np.ndarray:
     """Extract data from observed symbolic variables.
 
@@ -161,15 +166,31 @@
         mask[mask_idx] = 1
         return np.ma.MaskedArray(array_data, mask)
 
-    from pymc.logprob.utils import rvs_in_graph
-
     if not inputvars(x) and not rvs_in_graph(x):
-        cheap_eval_mode = Mode(linker="py", optimizer=None)
-        return x.eval(mode=cheap_eval_mode)
+        return x.eval(mode=_cheap_eval_mode)
 
     raise TypeError(f"Data cannot be extracted from {x}")
 
 
+def expand_inner_graph(r):
+    if (node := r.owner) is not None:
+        inputs = list(reversed(node.inputs))
+
+        if isinstance(node.op, HasInnerGraph):
+            inputs += node.op.inner_outputs
+
+        return inputs
+
+
+def rvs_in_graph(vars: Variable | Sequence[Variable]) -> set[Variable]:
+    """Return the random nodes (outputs of RNG-consuming Ops) found in a graph."""
+    return {
+        var
+        for var in walk(makeiter(vars), expand_inner_graph, False)
+        if (var.owner and isinstance(var.owner.op, RNGConsumerOp))
+    }
+
+
 def replace_vars_in_graphs(
@@ -720,8 +741,6 @@ def scan_step(xtm1):
 
         xs_draws = pm.draw(xs, draws=10)
     """
-    # Avoid circular import
-    from pymc.distributions.distribution import SymbolicRandomVariable
 
     def find_default_update(clients, rng: Variable) -> None | Variable:
         rng_clients = clients.get(rng, None)
@@ -764,48 +783,47 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
         [client, _] = rng_clients[0]
 
         # RNG is an output of the function, this is not a problem
-        if isinstance(client.op, Output):
-            return None
+        client_op = client.op
 
-        # RNG is used by another operator, which should output an update for the RNG
-        if isinstance(client.op, RandomVariable):
-            # RandomVariable first output is always the update of the input RNG
-            next_rng = client.outputs[0]
-
-        elif isinstance(client.op, SymbolicRandomVariable):
-            # SymbolicRandomVariable have an explicit method that returns an
-            # update mapping for their RNG(s)
-            next_rng = client.op.update(client).get(rng)
-            if next_rng is None:
-                raise ValueError(
-                    f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}"
-                )
-        elif isinstance(client.op, Scan):
-            # Check if any shared output corresponds to the RNG
-            rng_idx = client.inputs.index(rng)
-            io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"]
-            out_idx = io_map.get(rng_idx, -1)
-            if out_idx != -1:
-                next_rng = client.outputs[out_idx]
-            else:  # No break
-                raise ValueError(
-                    f"No update found for at least one RNG used in Scan Op {client.op}.\n"
-                    "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
- ) - elif isinstance(client.op, OpFromGraph): - try: - next_rng = collect_default_updates_inner_fgraph(client).get(rng) + match client_op: + case Output(): + return None + # Otherwise, RNG is used by another operator, which should output an update for the RNG + case RandomVariable(): + # RandomVariable first output is always the update of the input RNG + next_rng = client.outputs[0] + case RNGConsumerOp(): + # RNGConsumerOp have an explicit method that returns an update mapping for their RNG(s) + # RandomVariable is a subclass of RNGConsumerOp, but we specialize above for speedup + next_rng = client_op.update(client).get(rng) if next_rng is None: - # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning - return None - except ValueError as exc: - raise ValueError( - f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n" - "You can use `pytensorf.collect_default_updates` and include those updates as outputs." - ) from exc - else: - # We don't know how this RNG should be updated. The user should provide an update manually - return None + raise ValueError(f"No update found for at least one RNG used in {client_op}") + case Scan(): + # Check if any shared output corresponds to the RNG + rng_idx = client.inputs.index(rng) + io_map = client_op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"] + out_idx = io_map.get(rng_idx, -1) + if out_idx != -1: + next_rng = client.outputs[out_idx] + else: # No break + raise ValueError( + f"No update found for at least one RNG used in Scan Op {client_op}.\n" + "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically." + ) + case OpFromGraph(): + try: + next_rng = collect_default_updates_inner_fgraph(client).get(rng) + if next_rng is None: + # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning + return None + except ValueError as exc: + raise ValueError( + f"No update found for at least one RNG used in OpFromGraph Op {client_op}.\n" + "You can use `pytensorf.collect_default_updates` and include those updates as outputs." + ) from exc + case _: + # We don't know how this RNG should be updated. 
The user should provide an update manually + return None # Recurse until we find final update for RNG nested_next_rng = find_default_update(clients, next_rng) diff --git a/pymc/testing.py b/pymc/testing.py index b016c25ad..886177ef0 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -24,11 +24,13 @@ from arviz import InferenceData from numpy import random as nr from numpy import testing as npt +from pytensor.compile import SharedVariable from pytensor.compile.mode import Mode -from pytensor.graph.basic import Variable +from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs from pytensor.graph.rewriting.basic import in2out from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.type import RandomType from scipy import special as sp from scipy import stats as st @@ -41,9 +43,8 @@ from pymc.logprob.utils import ( ParameterValueError, local_check_parameter_to_ninf_switch, - rvs_in_graph, ) -from pymc.pytensorf import compile, floatX, inputvars +from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph # This mode can be used for tests where model compilations takes the bulk of the runtime # AND where we don't care about posterior numerical or sampling stability (e.g., when @@ -971,8 +972,7 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable: def assert_no_rvs(vars: Sequence[Variable]) -> None: """Assert that there are no `MeasurableOp` nodes in a graph.""" - rvs = rvs_in_graph(vars) - if rvs: + if rvs := rvs_in_graph(vars): raise AssertionError(f"RV found in graph: {rvs}") @@ -1086,3 +1086,28 @@ def test_model_inference(mock_pymc_sample): pm.sample = original_sample pm.Flat = original_flat pm.HalfFlat = original_half_flat + + +def equal_computations_up_to_root( + xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True +) -> bool: + # Check if graphs are equivalent even if root variables have distinct identities + + x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)] + y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)] + if len(x_graph_inputs) != len(y_graph_inputs): + return False + for x, y in zip(x_graph_inputs, y_graph_inputs): + if x.type != y.type: + return False + if x.name != y.name: + return False + if isinstance(x, SharedVariable): + if not isinstance(y, SharedVariable): + return False + if isinstance(x.type, RandomType) and ignore_rng_values: + continue + if not x.type.values_eq(x.get_value(), y.get_value()): + return False + + return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type] diff --git a/pyproject.toml b/pyproject.toml index a8ffb06ee..e999a19bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,9 @@ ignore = [ "D101", # Missing docstring in public class "D102", # Missing docstring in public method "D103", # Missing docstring in public function + "D104", # Missing docstring in public package "D105", # Missing docstring in magic method + "D401", # Ignore Umbridge level of control ] [tool.ruff.lint.pydocstyle] @@ -66,6 +68,13 @@ lines-between-types = 1 "pymc/__init__.py" = [ "E402", # Module level import not at top of file ] +"pymc/dims/__init__.py" = [ + "E402", # Module level import not at top of file +] +"pymc/dims/math.py" = [ + "F401", # Module imported but unused + "F403", # 'from module import *' used; unable to detect undefined names +] "pymc/stats/__init__.py" = [ "E402", # Module level import not at top 
of file ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 840f3d806..0f8069794 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,7 +16,7 @@ numpydoc pandas>=0.24.0 polyagamma pre-commit>=2.8.0 -pytensor>=2.31.2,<2.32 +pytensor>=2.31.5,<2.32 pytest-cov>=2.5 pytest>=3.0 rich>=13.7.1 diff --git a/requirements.txt b/requirements.txt index c1ca979dd..f1deaaaf5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cachetools>=4.2.1 cloudpickle numpy>=1.25.0 pandas>=0.24.0 -pytensor>=2.31.2,<2.32 +pytensor>=2.31.5,<2.32 rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0 diff --git a/tests/dims/__init__.py b/tests/dims/__init__.py new file mode 100644 index 000000000..00e50af6b --- /dev/null +++ b/tests/dims/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/dims/distributions/__init__.py b/tests/dims/distributions/__init__.py new file mode 100644 index 000000000..00e50af6b --- /dev/null +++ b/tests/dims/distributions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/dims/distributions/test_core.py b/tests/dims/distributions/test_core.py new file mode 100644 index 000000000..3f33a0ce1 --- /dev/null +++ b/tests/dims/distributions/test_core.py @@ -0,0 +1,191 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import re + +import numpy as np +import pytest + +import pymc as pm + +from pymc import dims as pmx + +pytestmark = pytest.mark.filterwarnings("error") + + +def test_distribution_dims(): + coords = { + "a": range(2), + "b": range(3), + "c": range(5), + "d": range(7), + } + with pm.Model(coords=coords) as model: + x = pmx.Data("x", np.random.randn(2, 3, 5), dims=("a", "b", "c")) + y1 = pmx.Normal("y1", mu=x) + assert y1.type.dims == ("a", "b", "c") + assert y1.eval().shape == (2, 3, 5) + + y2 = pmx.Normal("y2", mu=x, dims=("a", "b", "c")) # redundant + assert y2.type.dims == ("a", "b", "c") + assert y2.eval().shape == (2, 3, 5) + + y3 = pmx.Normal("y3", mu=x, dims=("b", "a", "c")) # Implies a transpose + assert y3.type.dims == ("b", "a", "c") + assert y3.eval().shape == (3, 2, 5) + + y4 = pmx.Normal("y4", mu=x, dims=("a", ...)) + assert y4.type.dims == ("a", "b", "c") + assert y4.eval().shape == (2, 3, 5) + + y5 = pmx.Normal("y5", mu=x, dims=("b", ...)) # Implies a transpose + assert y5.type.dims == ("b", "a", "c") + assert y5.eval().shape == (3, 2, 5) + + y6 = pmx.Normal("y6", mu=x, dims=("b", ..., "a")) # Implies a transpose + assert y6.type.dims == ("b", "c", "a") + assert y6.eval().shape == (3, 5, 2) + + y7 = pmx.Normal("y7", mu=x, dims=(..., "b")) # Implies a transpose + assert y7.type.dims == ("a", "c", "b") + assert y7.eval().shape == (2, 5, 3) + + y8 = pmx.Normal("y8", mu=x, dims=("d", "a", "b", "c")) # Adds an extra dimension + assert y8.type.dims == ("d", "a", "b", "c") + assert y8.eval().shape == (7, 2, 3, 5) + + y9 = pmx.Normal("y9", mu=x, dims=("d", ...)) # Adds an extra dimension + assert y9.type.dims == ("d", "a", "b", "c") + assert y9.eval().shape == (7, 2, 3, 5) + + y10 = pmx.Normal( + "y10", mu=x, dims=("b", "a", "c", "d") + ) # Adds an extra dimension and implies a transpose + assert y10.type.dims == ("b", "a", "c", "d") + assert y10.eval().shape == (3, 2, 5, 7) + + y11 = pmx.Normal( + "y11", mu=x, dims=("c", ..., "d") + ) # Adds an extra dimension and implies a transpose + assert y11.type.dims == ("c", "a", "b", "d") + assert y11.eval().shape == (5, 2, 3, 7) + + # Invalid cases + err_msg = "Provided dims ('a', 'b') do not match the distribution's output dims ('a', 'b', 'c'). Use ellipsis to specify all other dimensions." + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Missing a dimension + pmx.Normal("y_bad", mu=x, dims=("a", "b")) + + err_msg = "Provided dims ('d',) do not match the distribution's output dims ('d', 'a', 'b', 'c'). Use ellipsis to specify all other dimensions." + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Only specifies the extra dimension + pmx.Normal("y_bad", mu=x, dims=("d",)) + + err_msg = "Not all dims ('a', 'b', 'c', 'e') are part of the model coords. 
Add them at initialization time or use `model.add_coord` before defining the distribution" + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Extra dimension not in coords + pmx.Normal("y_bad", mu=x, dims=("a", "b", "c", "e")) + + +def test_multivariate_distribution_dims(): + coords = { + "batch": range(2), + "core1": range(3), + "core2": range(3), + } + with pm.Model(coords=coords) as m: + mu = pmx.Normal("mu", dims=("batch", "core1")) + chol, _, _ = pm.LKJCholeskyCov( + "chol", + eta=1, + n=3, + sd_dist=pm.Exponential.dist(1), + ) + chol_xr = pmx.math.as_xtensor(chol, dims=("core1", "core2")) + + x1 = pmx.MvNormal( + "x1", + mu, + chol=chol_xr, + core_dims=("core1", "core2"), + ) + assert x1.type.dims == ("batch", "core1") + assert x1.eval().shape == (2, 3) + + x2 = pmx.MvNormal( + "x2", + mu, + chol=chol_xr, + core_dims=("core1", "core2"), + dims=("batch", "core1"), + ) + assert x2.type.dims == ("batch", "core1") + assert x2.eval().shape == (2, 3) + + x3 = pmx.MvNormal( + "x3", + mu, + chol=chol_xr, + core_dims=("core2", "core1"), + dims=("batch", ...), + ) + assert x3.type.dims == ("batch", "core1") + assert x3.eval().shape == (2, 3) + + x4 = pmx.MvNormal( + "x4", + mu, + chol=chol_xr, + core_dims=("core1", "core2"), + # Implies transposition + dims=("core1", ...), + ) + assert x4.type.dims == ("core1", "batch") + assert x4.eval().shape == (3, 2) + + # Errors + err_msg = "MvNormal requires 2 core_dims" + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Missing core_dims + pmx.MvNormal( + "x_bad", + mu, + chol=chol_xr, + ) + + with pytest.raises(ValueError, match="Dimension batch not found in either input"): + pmx.MvNormal( + "x_bad", + mu, + chol=chol_xr, + # Invalid because batch is not on chol_xr + core_dims=("core1", "batch"), + ) + + with pytest.raises(ValueError, match="Parameter mu_renamed has invalid core dimensions"): + mu_renamed = mu.rename({"batch": "core2"}) + mu_renamed.name = "mu_renamed" + pmx.MvNormal( + "x_bad", + mu_renamed, + chol=chol_xr, + # Invalid because mu has both core dimensions (after renaming) + core_dims=("core1", "core2"), + ) + + # Invalid because core2 is not a core output dimension + err_msg = "Dimensions {'core2'} do not exist. Expected one or more of: ('batch', 'core1')" + with pytest.raises(ValueError, match=re.escape(err_msg)): + pmx.MvNormal( + "x_bad", mu, chol=chol_xr, core_dims=("core1", "core2"), dims=("core2", ...) + ) diff --git a/tests/dims/distributions/test_scalar.py b/tests/dims/distributions/test_scalar.py new file mode 100644 index 000000000..a487591c0 --- /dev/null +++ b/tests/dims/distributions/test_scalar.py @@ -0,0 +1,217 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pymc import Model +from pymc import distributions as regular_distributions +from pymc.dims import ( + Beta, + Cauchy, + Exponential, + Flat, + Gamma, + HalfCauchy, + HalfFlat, + HalfNormal, + InverseGamma, + Laplace, + LogNormal, + Normal, + StudentT, +) +from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph + + +def test_flat(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Flat("x", dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Flat("x", dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_halfflat(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + HalfFlat("x", dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfFlat("x", dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_normal(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Normal("x", dims="a") + Normal("y", mu=2, sigma=3, dims="a") + Normal("z", mu=-2, tau=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Normal("x", dims="a") + regular_distributions.Normal("y", mu=2, sigma=3, dims="a") + regular_distributions.Normal("z", mu=-2, tau=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_halfnormal(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + HalfNormal("x", dims="a") + HalfNormal("y", sigma=3, dims="a") + HalfNormal("z", tau=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfNormal("x", dims="a") + regular_distributions.HalfNormal("y", sigma=3, dims="a") + regular_distributions.HalfNormal("z", tau=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_lognormal(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + LogNormal("x", dims="a") + LogNormal("y", mu=2, sigma=3, dims="a") + LogNormal("z", mu=-2, tau=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.LogNormal("x", dims="a") + regular_distributions.LogNormal("y", mu=2, sigma=3, dims="a") + regular_distributions.LogNormal("z", mu=-2, tau=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_studentt(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + StudentT("x", nu=1, dims="a") + StudentT("y", nu=1, mu=2, sigma=3, dims="a") + StudentT("z", nu=1, mu=-2, lam=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.StudentT("x", nu=1, dims="a") + regular_distributions.StudentT("y", nu=1, mu=2, sigma=3, dims="a") + regular_distributions.StudentT("z", nu=1, mu=-2, lam=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_cauchy(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Cauchy("x", alpha=1, beta=2, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Cauchy("x", alpha=1, beta=2, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_halfcauchy(): + coords = 
{"a": range(3)} + with Model(coords=coords) as model: + HalfCauchy("x", beta=2, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfCauchy("x", beta=2, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_beta(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Beta("w", alpha=1, beta=1, dims="a") + Beta("x", mu=0.5, sigma=0.1, dims="a") + Beta("y", mu=0.5, nu=10, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Beta("w", alpha=1, beta=1, dims="a") + regular_distributions.Beta("x", mu=0.5, sigma=0.1, dims="a") + regular_distributions.Beta("y", mu=0.5, nu=10, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_laplace(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Laplace("x", dims="a") + Laplace("y", mu=1, b=2, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Laplace("x", mu=0, b=1, dims="a") + regular_distributions.Laplace("y", mu=1, b=2, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_exponential(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Exponential("x", dims="a") + Exponential("y", lam=2, dims="a") + Exponential("z", scale=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Exponential("x", dims="a") + regular_distributions.Exponential("y", lam=2, dims="a") + regular_distributions.Exponential("z", scale=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_gamma(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + # Gamma("w", alpha=1, beta=1, dims="a") + Gamma("x", mu=2, sigma=3, dims="a") + + with Model(coords=coords) as reference_model: + # regular_distributions.Gamma("w", alpha=1, beta=1, dims="a") + regular_distributions.Gamma("x", mu=2, sigma=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_inverse_gamma(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + InverseGamma("w", alpha=1, beta=1, dims="a") + InverseGamma("x", mu=2, sigma=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.InverseGamma("w", alpha=1, beta=1, dims="a") + regular_distributions.InverseGamma("x", mu=2, sigma=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py new file mode 100644 index 000000000..0f08505db --- /dev/null +++ b/tests/dims/distributions/test_vector.py @@ -0,0 +1,62 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor.tensor as pt + +from pytensor.xtensor import as_xtensor + +import pymc.distributions as regular_distributions + +from pymc import Model +from pymc.dims import Categorical, MvNormal +from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph + + +def test_categorical(): + coords = {"a": range(3), "b": range(4)} + p = pt.as_tensor([0.1, 0.2, 0.3, 0.4]) + p_xr = as_xtensor(p, dims=("b",)) + + with Model(coords=coords) as model: + Categorical("x", p=p_xr, core_dims="b", dims=("a",)) + Categorical("y", logit_p=p_xr, core_dims="b", dims=("a",)) + + with Model(coords=coords) as reference_model: + regular_distributions.Categorical("x", p=p, dims=("a",)) + regular_distributions.Categorical("y", logit_p=p, dims=("a",)) + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_mvnormal(): + coords = {"a": range(3), "b": range(2)} + mu = pt.as_tensor([1, 2]) + cov = pt.as_tensor([[1, 0.5], [0.5, 2]]) + chol = pt.as_tensor([[1, 0], [0.5, np.sqrt(1.75)]]) + + mu_xr = as_xtensor(mu, dims=("b",)) + cov_xr = as_xtensor(cov, dims=("b", "b'")) + chol_xr = as_xtensor(chol, dims=("b", "b'")) + + with Model(coords=coords) as model: + MvNormal("x", mu=mu_xr, cov=cov_xr, core_dims=("b", "b'"), dims=("a", "b")) + MvNormal("y", mu=mu_xr, chol=chol_xr, core_dims=("b", "b'"), dims=("a", "b")) + + with Model(coords=coords) as reference_model: + regular_distributions.MvNormal("x", mu=mu, cov=cov, dims=("a", "b")) + regular_distributions.MvNormal("y", mu=mu, chol=chol, dims=("a", "b")) + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) diff --git a/tests/dims/test_model.py b/tests/dims/test_model.py new file mode 100644 index 000000000..bf27bf482 --- /dev/null +++ b/tests/dims/test_model.py @@ -0,0 +1,174 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytest + +from pytensor.xtensor.type import XTensorType +from xarray import DataArray + +import pymc as pm + +from pymc import dims as pmd +from pymc import observe + +pytestmark = pytest.mark.filterwarnings("error") + + +def test_data(): + x_np = np.random.randn(10, 2, 3) + + with pm.Model() as m: + x1 = pmd.Data("x", x_np, dims=("a", "b", "c")) + assert isinstance(x1.type, XTensorType) + assert x1.type.dims == ("a", "b", "c") + + +def test_simple_model(): + coords = {"a": range(3), "b": range(5)} + + with pm.Model(coords=coords) as model: + x = pmd.Normal("x", mu=1, dims=("a", "b")) + sigma = pmd.HalfNormal("sigma", dims=("a",)) + y = pmd.Normal("y", mu=x.T * 2, sigma=sigma, dims=("b", "a")) + + with pm.Model(coords=coords) as xmodel: + x = pmd.Normal("x", mu=1, dims=("a", "b")) + sigma = pmd.HalfNormal("sigma", dims=("a",)) + # Imply a transposition + y = pmd.Normal("y", mu=x * 2, sigma=sigma, dims=("b", "a")) + + assert x.type.dims == ("a", "b") + assert sigma.type.dims == ("a",) + assert y.type.dims == ("b", "a") + + ip = model.initial_point() + xip = xmodel.initial_point() + assert ip.keys() == xip.keys() + for value, xvalue in zip(ip.values(), xip.values()): + np.testing.assert_allclose(value, xvalue) + + logp = model.compile_logp()(ip) + xlogp = xmodel.compile_logp()(xip) + np.testing.assert_allclose(logp, xlogp) + + dlogp = model.compile_dlogp()(ip) + xdlogp = xmodel.compile_dlogp()(xip) + np.testing.assert_allclose(dlogp, xdlogp) + + draw = pm.draw(xmodel["y"], random_seed=1) + draw_same = pm.draw(xmodel["y"], random_seed=1) + draw_diff = pm.draw(xmodel["y"], random_seed=2) + assert draw.shape == (5, 3) + np.testing.assert_allclose(draw, draw_same) + assert not np.allclose(draw, draw_diff) + + observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")).transpose() + with observe(xmodel, {"y": observed_values}): + pm.sample_prior_predictive() + idata = pm.sample( + tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False + ) + pm.sample_posterior_predictive(idata, progressbar=False) + + +def test_complex_model(): + N = 100 + rng = np.random.default_rng(4) + x_np = np.linspace(0, 10, N) + y_np = np.piecewise( + x_np, + [x_np <= 3, (x_np > 3) & (x_np <= 7), x_np > 7], + [lambda x: 0.5 * x, lambda x: 1.5 + 0.2 * (x - 3), lambda x: 2.3 - 0.1 * (x - 7)], + ) + y_np += rng.normal(0, 0.2, size=N) + group_idx_np = rng.choice(3, size=N) + N_knots = 13 + knots_np = np.linspace(0, 10, num=N_knots) + + coords = { + "group": range(3), + "knots": range(N_knots), + "obs": range(N), + } + + with pm.Model(coords=coords) as model: + x = pm.Data("x", x_np, dims="obs") + knots = pm.Data("knots", knots_np, dims="knot") + + sigma = pm.HalfCauchy("sigma", beta=1) + sigma_beta0 = pm.HalfNormal("sigma_beta0", sigma=10) + beta0 = pm.HalfNormal("beta_0", sigma=sigma_beta0, dims="group") + z = pm.Normal("z", dims=("group", "knot")) + + delta_factors = pm.math.softmax(z, axis=-1) # (groups, knot) + slope_factors = 1 - delta_factors[:, :-1].cumsum(axis=-1) # (groups, knot-1) + spline_slopes = pm.math.concatenate( + [beta0[:, None], beta0[:, None] * slope_factors], axis=-1 + ) # (groups, knot-1) + beta = pm.math.concatenate( + [beta0[:, None], pm.math.diff(spline_slopes)], axis=-1 + ) # (groups, knot) + + beta = pm.Deterministic("beta", beta, dims=("group", "knot")) + + X = pm.math.maximum(0, x[:, None] - knots[None, :]) # (n, knot) + mu = (X * beta[group_idx_np]).sum(-1) # ((n, knots) * (n, knots)).sum(-1) = (n,) + y_obs = pm.Normal("y_obs", mu=mu, 
sigma=sigma, observed=y_np, dims="obs")
+    with pm.Model(coords=coords) as xmodel:
+        x = pmd.Data("x", x_np, dims="obs")
+        y = pmd.Data("y", y_np, dims="obs")
+        knots = pmd.Data("knots", knots_np, dims=("knot",))
+        group_idx = pmd.math.as_xtensor(group_idx_np, dims=("obs",))
+
+        sigma = pmd.HalfCauchy("sigma", beta=1)
+        sigma_beta0 = pmd.HalfNormal("sigma_beta0", sigma=10)
+        beta0 = pmd.HalfNormal("beta_0", sigma=sigma_beta0, dims=("group",))
+        z = pmd.Normal("z", dims=("group", "knot"))
+
+        delta_factors = pmd.math.softmax(z, dim="knot")
+        slope_factors = 1 - delta_factors.isel(knot=slice(None, -1)).cumsum("knot")
+        spline_slopes = pmd.concat([beta0, beta0 * slope_factors], dim="knot")
+        beta = pmd.concat([beta0, spline_slopes.diff("knot")], dim="knot")
+
+        beta = pm.Deterministic("beta", beta)
+
+        X = pmd.math.maximum(0, x - knots)
+        mu = (X * beta.isel(group=group_idx)).sum("knot")
+        y_obs = pmd.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
+
+    # Test initial point
+    model_ip = model.initial_point()
+    xmodel_ip = xmodel.initial_point()
+    assert model_ip.keys() == xmodel_ip.keys()
+    for value, xvalue in zip(model_ip.values(), xmodel_ip.values()):
+        np.testing.assert_allclose(value, xvalue)
+
+    # Test logp
+    model_logp = model.compile_logp()(model_ip)
+    xmodel_logp = xmodel.compile_logp()(xmodel_ip)
+    np.testing.assert_allclose(model_logp, xmodel_logp)
+
+    # Test random draws
+    model_draw = pm.draw(model["y_obs"], random_seed=1)
+    xmodel_draw = pm.draw(xmodel["y_obs"], random_seed=1)
+    np.testing.assert_allclose(model_draw, xmodel_draw)
+
+    with xmodel:
+        pm.sample_prior_predictive()
+        idata = pm.sample(
+            tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False
+        )
+        pm.sample_posterior_predictive(idata, progressbar=False)
diff --git a/tests/dims/utils.py b/tests/dims/utils.py
new file mode 100644
index 000000000..e84dba4a5
--- /dev/null
+++ b/tests/dims/utils.py
@@ -0,0 +1,64 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytensor.graph import rewrite_graph +from pytensor.printing import debugprint + +from pymc import Model +from pymc.testing import equal_computations_up_to_root + + +def assert_equivalent_random_graph(model: Model, reference_model: Model) -> bool: + """Check if the random graph of a model with xtensor variables is equivalent.""" + lowered_model = rewrite_graph( + model.basic_RVs + model.deterministics + model.potentials, + include=( + "lower_xtensor", + "inline_ofg_expansion_xtensor", + "canonicalize", + "local_remove_all_assert", + ), + ) + reference_lowered_model = rewrite_graph( + reference_model.basic_RVs + reference_model.deterministics + reference_model.potentials, + include=( + "inline_ofg_expansion", + "canonicalize", + "local_remove_all_assert", + ), + ) + assert equal_computations_up_to_root( + lowered_model, + reference_lowered_model, + ignore_rng_values=True, + ), debugprint(lowered_model + reference_lowered_model, print_type=True) + + +def assert_equivalent_logp_graph(model: Model, reference_model: Model) -> bool: + """Check if the logp graph of a model with xtensor variables is equivalent.""" + lowered_model_logp = rewrite_graph( + [model.logp()], + include=("lower_xtensor", "canonicalize", "local_remove_all_assert"), + ) + reference_lowered_model_logp = rewrite_graph( + [reference_model.logp()], + include=("canonicalize", "local_remove_all_assert"), + ) + assert equal_computations_up_to_root( + lowered_model_logp, + reference_lowered_model_logp, + ignore_rng_values=False, + ), debugprint( + lowered_model_logp + reference_lowered_model_logp, + print_type=True, + ) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 9c108d203..720938266 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -693,7 +693,8 @@ def test_inverse_gamma_logcdf(self): ) def test_inverse_gamma_alt_params(self): def test_fun(value, mu, sigma): - alpha, beta = pm.InverseGamma._get_alpha_beta(None, None, mu, sigma) + alpha = (2 * sigma**2 + mu**2) / sigma**2 + beta = mu * (mu**2 + sigma**2) / sigma**2 return st.invgamma.logpdf(value, alpha, scale=beta) check_logp( diff --git a/tests/distributions/test_shape_utils.py b/tests/distributions/test_shape_utils.py index d0f3f1b43..8d787df61 100644 --- a/tests/distributions/test_shape_utils.py +++ b/tests/distributions/test_shape_utils.py @@ -31,6 +31,7 @@ from pymc.distributions.shape_utils import ( change_dist_size, convert_dims, + convert_dims_with_ellipsis, convert_shape, convert_size, get_support_shape, @@ -297,6 +298,14 @@ def test_convert_dims(self): assert convert_dims(dims="town") == ("town",) with pytest.raises(ValueError, match="must be a tuple, str or list"): convert_dims(3) + with pytest.raises(ValueError, match="must be a tuple, str or list"): + convert_dims(...) + + def test_convert_dims_with_ellipsis(self): + assert convert_dims_with_ellipsis(dims="town") == ("town",) + assert convert_dims_with_ellipsis(...) 
== (...,) + with pytest.raises(ValueError, match="must be a tuple, list, str or Ellipsis"): + convert_dims_with_ellipsis(3) def test_convert_shape(self): assert convert_shape(5) == (5,) From 35b7dda6f8aa2538166fb0cf2b4b9dc2749ea2ac Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 30 Jun 2025 12:47:41 +0200 Subject: [PATCH 05/11] Allow registering XTensorVariables directly in model --- pymc/dims/__init__.py | 14 ++- pymc/dims/distributions/core.py | 127 ++++++++++++++++++++++---- pymc/dims/distributions/transforms.py | 53 +++++++++++ pymc/dims/model.py | 30 +----- pymc/initial_point.py | 20 +++- pymc/logprob/basic.py | 12 +-- pymc/logprob/rewriting.py | 7 +- pymc/model/core.py | 14 ++- pymc/pytensorf.py | 25 +++-- pymc/step_methods/metropolis.py | 8 +- scripts/run_mypy.py | 1 - tests/dims/test_model.py | 2 +- tests/dims/utils.py | 11 ++- 13 files changed, 245 insertions(+), 79 deletions(-) create mode 100644 pymc/dims/distributions/transforms.py diff --git a/pymc/dims/__init__.py b/pymc/dims/__init__.py index aa9a3c983..a0b77b165 100644 --- a/pymc/dims/__init__.py +++ b/pymc/dims/__init__.py @@ -36,9 +36,15 @@ def __init__(): # Make PyMC aware of xtensor functionality MeasurableOp.register(XRV) - lower_xtensor_query = optdb.query("+lower_xtensor") - logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1) - initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1) + logprob_rewrites_db.register( + "pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1 + ) + logprob_rewrites_db.register( + "post_lower_xtensor", optdb.query("+lower_xtensor"), "cleanup", position=5.1 + ) + initial_point_rewrites_db.register( + "lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1 + ) # TODO: Better model of probability of bugs day_of_conception = datetime.date(2025, 6, 17) @@ -64,4 +70,4 @@ def __init__(): from pymc.dims import math from pymc.dims.distributions import * -from pymc.dims.model import Data, Deterministic, Potential, with_dims +from pymc.dims.model import Data, Deterministic, Potential diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py index bd48db0ec..fb06457dc 100644 --- a/pymc/dims/distributions/core.py +++ b/pymc/dims/distributions/core.py @@ -13,18 +13,26 @@ # limitations under the License. 
 from collections.abc import Callable, Sequence
 from itertools import chain
+from typing import cast
+import numpy as np
+
+from pytensor.graph import node_rewriter
 from pytensor.graph.basic import Variable
 from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.random.op import RandomVariable
 from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
 from pytensor.xtensor.type import XTensorVariable

-from pymc import modelcontext
-from pymc.dims.model import with_dims
-from pymc.distributions import transforms
+from pymc import SymbolicRandomVariable, modelcontext
+from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform
 from pymc.distributions.distribution import _support_point, support_point
 from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis
-from pymc.logprob.transforms import Transform
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.tensor import MeasurableDimShuffle
+from pymc.logprob.utils import filter_measurable_variables
 from pymc.util import UNSET
@@ -36,25 +44,98 @@ def dimshuffle_support_point(ds_op, _, rv):
     return ds_op(support_point(rv))


+@_support_point.register(XTensorFromTensor)
+def xtensor_from_tensor_support_point(xtensor_op, _, rv):
+    # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
+    return xtensor_op(support_point(rv))
+
+
+class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor):
+    __props__ = ("dims", "core_dims")  # type: ignore[assignment]
+
+    def __init__(self, dims, core_dims):
+        super().__init__(dims=dims)
+        self.core_dims = tuple(core_dims) if core_dims is not None else None
+
+
+@node_rewriter([XTensorFromTensor])
+def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None:
+    if isinstance(node.op, MeasurableXTensorFromTensor):
+        return None
+
+    xs = filter_measurable_variables(node.inputs)
+
+    if not xs:
+        # Check if we have a transposition instead
+        # The rewrite that introduces measurable transposes refuses to apply to multivariate RVs
+        # So we have a chance of inferring the core dims!
+        [ds] = node.inputs
+        ds_node = ds.owner
+        if not (
+            ds_node is not None
+            and isinstance(ds_node.op, DimShuffle)
+            and ds_node.op.is_transpose
+            and filter_measurable_variables(ds_node.inputs)
+        ):
+            return None
+        [x] = ds_node.inputs
+        if not (
+            x.owner is not None and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
+        ):
+            return None
+
+        measurable_x = MeasurableDimShuffle(**ds_node.op._props_dict())(x)  # type: ignore[attr-defined]
+
+        ndim_supp = x.owner.op.ndim_supp
+        if ndim_supp:
+            inverse_transpose = np.argsort(ds_node.op.shuffle)
+            dims = node.op.dims
+            dims_before_transpose = tuple(dims[i] for i in inverse_transpose)
+            core_dims = dims_before_transpose[-ndim_supp:]
+        else:
+            core_dims = ()
+
+        new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=core_dims)(measurable_x)
+    else:
+        # If this happens we know there's no measurable transpose in between and we can
+        # safely infer the core_dims positionally when the inner logp is returned
+        new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=None)(*node.inputs)
+    return [cast(XTensorVariable, new_out)]
+
+
+@_logprob.register(MeasurableXTensorFromTensor)
+def measurable_xtensor_from_tensor(op, values, rv, **kwargs):
+    rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs)
+    if op.core_dims is None:
+        # The core_dims of the inner rv are on the right
+        dims = op.dims[: rv_logp.ndim]
+    else:
+        # We inferred where the core_dims are!
+        dims = [d for d in op.dims if d not in op.core_dims]
+    return xtensor_from_tensor(rv_logp, dims=dims)
+
+
+measurable_ir_rewrites_db.register(
+    "measurable_xtensor_from_tensor", find_measurable_xtensor_from_tensor, "basic", "xtensor"
+)
+
+
 class DimDistribution:
     """Base class for PyMC distributions that wrap pytensor.xtensor.random operations and follow xarray-like semantics."""

     xrv_op: Callable
-    default_transform: Transform | None = None
+    default_transform: DimTransform | None = None

     @staticmethod
     def _as_xtensor(x):
         try:
             return as_xtensor(x)
         except TypeError:
-            try:
-                return with_dims(x)
-            except ValueError:
-                raise ValueError(
-                    f"Variable {x} must have dims associated with it.\n"
-                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
-                    "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
-                )
+            raise ValueError(
+                f"Variable {x} must have dims associated with it.\n"
+                "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
+                "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
+            )

     def __new__(
         cls,
@@ -119,10 +200,22 @@ def __new__(
         else:
             # Align observed dims with those of the RV
             # TODO: If this fails give a more informative error message
-            observed = observed.transpose(*rv_dims).values
+            observed = observed.transpose(*rv_dims)
+
+        # Check user didn't pass regular transforms
+        if transform not in (UNSET, None):
+            if not isinstance(transform, DimTransform):
+                raise TypeError(
+                    f"Transform must be a DimTransform, from pymc.dims.transforms, but got {type(transform)}."
+                )
+        if default_transform not in (UNSET, None):
+            if not isinstance(default_transform, DimTransform):
+                raise TypeError(
+                    f"default_transform must be a DimTransform, from pymc.dims.transforms, but got {type(default_transform)}."
+                )

         rv = model.register_rv(
-            rv.values,
+            rv,
             name=name,
             observed=observed,
             total_size=total_size,
@@ -182,10 +275,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
 class PositiveDimDistribution(DimDistribution):
     """Base class for positive continuous distributions."""

-    default_transform = transforms.log
+    default_transform = log_transform


 class UnitDimDistribution(DimDistribution):
     """Base class for unit-valued distributions."""

-    default_transform = transforms.logodds
+    default_transform = log_odds_transform
diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py
new file mode 100644
index 000000000..6805d1b5c
--- /dev/null
+++ b/pymc/dims/distributions/transforms.py
@@ -0,0 +1,53 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytensor.xtensor as ptx
+
+from pymc.logprob.transforms import Transform
+
+
+class DimTransform(Transform):
+    """Base class for transforms that are applied to dim distributions."""
+
+
+class LogTransform(DimTransform):
+    name = "log"
+
+    def forward(self, value, *inputs):
+        return ptx.math.log(value)
+
+    def backward(self, value, *inputs):
+        return ptx.math.exp(value)
+
+    def log_jac_det(self, value, *inputs):
+        return value
+
+
+log_transform = LogTransform()
+
+
+class LogOddsTransform(DimTransform):
+    name = "logodds"
+
+    def backward(self, value, *inputs):
+        return ptx.math.expit(value)
+
+    def forward(self, value, *inputs):
+        return ptx.math.log(value / (1 - value))
+
+    def log_jac_det(self, value, *inputs):
+        sigmoid_value = ptx.math.sigmoid(value)
+        return ptx.math.log(sigmoid_value) + ptx.math.log1p(-sigmoid_value)
+
+
+log_odds_transform = LogOddsTransform()
diff --git a/pymc/dims/model.py b/pymc/dims/model.py
index e76bf06ec..263497575 100644
--- a/pymc/dims/model.py
+++ b/pymc/dims/model.py
@@ -15,7 +15,6 @@
 from pytensor.tensor import TensorVariable
 from pytensor.xtensor import as_xtensor
-from pytensor.xtensor.basic import TensorFromXTensor
 from pytensor.xtensor.type import XTensorVariable

 from pymc.data import Data as RegularData
@@ -30,38 +29,15 @@
 from pymc.model.core import Potential as RegularPotential


-def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable:
-    """Recover the dims of a variable that was registered in the Model."""
-    if isinstance(x, XTensorVariable):
-        return x
-
-    if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor):
-        dims = x.owner.inputs[0].type.dims
-        return as_xtensor(x, dims=dims, name=x.name)
-
-    # Try accessing the model context to get dims
-    try:
-        model = modelcontext(model)
-        if (
-            model.named_vars.get(x.name, None) is x
-            and (dims := model.named_vars_to_dims.get(x.name, None)) is not None
-        ):
-            return as_xtensor(x, dims=dims, name=x.name)
-    except TypeError:
-        pass
-
-    raise ValueError(f"variable {x} doesn't have dims associated with it")
-
-
 def Data(
     name: str, value, dims: Dims = None, model: Model | None = None, **kwargs
 ) ->
XTensorVariable: """Wrapper around regular Data that returns an XtensorVariable.""" model = modelcontext(model) - dims = convert_dims(dims) + dims = convert_dims(dims) # type: ignore[assignment] with model: - value = RegularData(name, value, dims=dims, **kwargs) + value = RegularData(name, value, dims=dims, **kwargs) # type: ignore[arg-type] dims = model.named_vars_to_dims[value.name] if dims is None and value.ndim > 0: @@ -84,8 +60,6 @@ def _register_and_return_xtensor_variable( value = value.transpose(*dims) # Regardless of whether dims are provided, we now have them dims = value.type.dims - # Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly. - value = value.values # type: ignore[union-attr] value = registration_func(name, value, dims=dims, model=model) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 4704979e3..87842a998 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -20,7 +20,8 @@ import pytensor import pytensor.tensor as pt -from pytensor.graph.basic import Constant, Variable +from pytensor.compile.ops import TypeCastingOp +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.tensor.variable import TensorVariable @@ -195,6 +196,14 @@ def inner(seed, *args, **kwargs): return make_seeded_function(func) +class InitialPoint(TypeCastingOp): + def make_node(self, var): + return Apply(self, [var], [var.type()]) + + +initial_point_op = InitialPoint() + + def make_initial_point_expression( *, free_rvs: Sequence[TensorVariable], @@ -235,6 +244,9 @@ def make_initial_point_expression( # Clone free_rvs so we don't modify the original graph initial_point_fgraph = FunctionGraph(outputs=free_rvs, clone=True) + # Wrap each rv in an initial_point Operation to avoid losing dependency between the RVs + replacements = tuple((rv, initial_point_op(rv)) for rv in initial_point_fgraph.outputs) + toposort_replace(initial_point_fgraph, replacements, reverse=True) # Apply any rewrites necessary to compute the initial points. initial_point_rewriter = initial_point_rewrites_db.query(initial_point_basic_query) @@ -245,7 +257,9 @@ def make_initial_point_expression( initial_values = [] initial_values_transformed = [] - for old_variable, variable in zip(free_rvs, free_rvs_clone): + for old_variable, ip_variable in zip(free_rvs, free_rvs_clone): + # Extract the variable from the initial_point operation + [variable] = ip_variable.owner.inputs strategy = initval_strategies.get(old_variable) if strategy is None: @@ -257,7 +271,7 @@ def make_initial_point_expression( value = support_point(variable) except NotImplementedError: warnings.warn( - f"Moment not defined for variable {variable} of type " + f"support_point not defined for variable {variable} of type " f"{variable.owner.op.__class__.__name__}, defaulting to " f"a draw from the prior. This can lead to difficulties " f"during tuning. 
You can manually define an initval or " diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 9fb76df16..76483ad69 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -197,7 +197,7 @@ def normal_logp(value, mu, sigma): [ir_valued_var] = fgraph.outputs [ir_rv, ir_value] = ir_valued_var.owner.inputs expr = _logprob_helper(ir_rv, ir_value, **kwargs) - cleanup_ir([expr]) + [expr] = cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) return expr @@ -297,7 +297,7 @@ def normal_logcdf(value, mu, sigma): [ir_valued_rv] = fgraph.outputs [ir_rv, ir_value] = ir_valued_rv.owner.inputs expr = _logcdf_helper(ir_rv, ir_value, **kwargs) - cleanup_ir([expr]) + [expr] = cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) return expr @@ -379,7 +379,7 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> Tens [ir_valued_rv] = fgraph.outputs [ir_rv, ir_value] = ir_valued_rv.owner.inputs expr = _icdf_helper(ir_rv, ir_value, **kwargs) - cleanup_ir([expr]) + [expr] = cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) return expr @@ -540,15 +540,15 @@ def conditional_logp( f"The logprob terms of the following value variables could not be derived: {missing_value_terms}" ) - logprobs = list(values_to_logprobs.values()) - cleanup_ir(logprobs) + values, logprobs = zip(*values_to_logprobs.items()) + logprobs = cleanup_ir(logprobs) if warn_rvs: rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs) if rvs_in_logp_expressions: warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning) - return values_to_logprobs + return dict(zip(values, logprobs)) def transformed_conditional_logp( diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index ea0202f00..af0d8d01e 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -133,6 +133,8 @@ def remove_DiracDelta(fgraph, node): logprob_rewrites_basic_query = RewriteDatabaseQuery(include=["basic"]) +logprob_rewrites_cleanup_query = RewriteDatabaseQuery(include=["cleanup"]) + logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" @@ -276,10 +278,11 @@ def construct_ir_fgraph( return fgraph -def cleanup_ir(vars: Sequence[Variable]) -> None: +def cleanup_ir(vars: Sequence[Variable]) -> Sequence[Variable]: fgraph = FunctionGraph(outputs=vars, clone=False) - ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"])) + ir_rewriter = logprob_rewrites_db.query(logprob_rewrites_cleanup_query) ir_rewriter.rewrite(fgraph) + return fgraph.outputs def assume_valued_outputs(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]: diff --git a/pymc/model/core.py b/pymc/model/core.py index 66e633e15..5ec5c0ec3 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -35,6 +35,8 @@ from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs +from pytensor.tensor import as_tensor +from pytensor.tensor.math import variadic_add from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -231,7 +233,9 @@ def __init__( grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore") for grad_wrt, var in zip(grads, grad_vars): grad_wrt.name = f"{var.name}_grad" - grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in 
grads]) + grads = pt.join( + 0, *[as_tensor(grad, allow_xtensor_conversion=True).ravel() for grad in grads] + ) outputs = [cost, grads] else: outputs = [cost] @@ -708,7 +712,9 @@ def logp( if not sum: return logp_factors - logp_scalar = pt.sum([pt.sum(factor) for factor in logp_factors]) + logp_scalar = variadic_add( + *(as_tensor(factor, allow_xtensor_conversion=True).sum() for factor in logp_factors) + ) logp_scalar_name = "__logp" if jacobian else "__logp_nojac" if self.name: logp_scalar_name = f"{logp_scalar_name}_{self.name}" @@ -1328,7 +1334,7 @@ def make_obs_var( else: if sps.issparse(data): data = sparse.basic.as_sparse(data, name=name) - else: + elif not isinstance(data, Variable): data = pt.as_tensor_variable(data, name=name) if total_size: @@ -1781,7 +1787,7 @@ def point_logps(self, point=None, round_vals=2, **kwargs): point = self.initial_point() factors = self.basic_RVs + self.potentials - factor_logps_fn = [pt.sum(factor) for factor in self.logp(factors, sum=False)] + factor_logps_fn = [factor.sum() for factor in self.logp(factors, sum=False)] return { factor.name: np.round(np.asarray(factor_logp), round_vals) for factor, factor_logp in zip( diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 3ce308147..1226e6ce9 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -46,7 +46,7 @@ from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding from pytensor.tensor.rewriting.shape import ShapeFeature -from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable +from pytensor.tensor.sharedvar import SharedVariable from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 from pytensor.tensor.variable import TensorVariable @@ -300,7 +300,9 @@ def smarttypeX(x): def gradient1(f, v): """Flat gradient of f wrt v.""" - return pt.flatten(grad(f, v, disconnected_inputs="warn")) + return pt.as_tensor( + grad(f, v, disconnected_inputs="warn"), allow_xtensor_conversion=True + ).ravel() empty_gradient = pt.zeros(0, dtype="float32") @@ -419,11 +421,11 @@ def make_shared_replacements(point, vars, model): def join_nonshared_inputs( point: dict[str, np.ndarray], - outputs: list[TensorVariable], - inputs: list[TensorVariable], - shared_inputs: dict[TensorVariable, TensorSharedVariable] | None = None, + outputs: Sequence[Variable], + inputs: Sequence[Variable], + shared_inputs: dict[Variable, Variable] | None = None, make_inputs_shared: bool = False, -) -> tuple[list[TensorVariable], TensorVariable]: +) -> tuple[Sequence[Variable], TensorVariable]: """ Create new outputs and input TensorVariables where the non-shared inputs are joined in a single raveled vector input. 
@@ -548,7 +550,9 @@ def join_nonshared_inputs( if not inputs: raise ValueError("Empty list of input variables.") - raveled_inputs = pt.concatenate([var.ravel() for var in inputs]) + raveled_inputs = pt.concatenate( + [pt.as_tensor(var, allow_xtensor_conversion=True).ravel() for var in inputs] + ) if not make_inputs_shared: tensor_type = raveled_inputs.type @@ -560,12 +564,15 @@ def join_nonshared_inputs( if pytensor.config.compute_test_value != "off": joined_inputs.tag.test_value = raveled_inputs.tag.test_value - replace: dict[TensorVariable, TensorVariable] = {} + replace: dict[Variable, Variable] = {} last_idx = 0 for var in inputs: shape = point[var.name].shape arr_len = np.prod(shape, dtype=int) - replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) + replacement_var = ( + joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) + ) + replace[var] = var.type.filter_variable(replacement_var) last_idx += arr_len if shared_inputs is not None: diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653..6bf3d6b9e 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Callable from dataclasses import field -from typing import Any +from typing import Any, cast import numpy as np import numpy.random as nr @@ -22,6 +22,7 @@ import scipy.special from pytensor import tensor as pt +from pytensor.graph.basic import Variable from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV from rich.progress import TextColumn @@ -1263,7 +1264,10 @@ def delta_logp( compile_kwargs: dict | None, ) -> pytensor.compile.Function: [logp0], inarray0 = join_nonshared_inputs( - point=point, outputs=[logp], inputs=vars, shared_inputs=shared + point=point, + outputs=[logp], + inputs=vars, + shared_inputs=cast(dict[Variable, Variable], shared), ) tensor_type = inarray0.type diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 032fbc938..3f104684a 100755 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -44,7 +44,6 @@ pymc/model/transform/conditioning.py pymc/pytensorf.py pymc/sampling/jax.py -pymc/sampling/mcmc.py """ diff --git a/tests/dims/test_model.py b/tests/dims/test_model.py index bf27bf482..c67f629b9 100644 --- a/tests/dims/test_model.py +++ b/tests/dims/test_model.py @@ -73,7 +73,7 @@ def test_simple_model(): np.testing.assert_allclose(draw, draw_same) assert not np.allclose(draw, draw_diff) - observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")).transpose() + observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")) with observe(xmodel, {"y": observed_values}): pm.sample_prior_predictive() idata = pm.sample( diff --git a/tests/dims/utils.py b/tests/dims/utils.py index e84dba4a5..07a340695 100644 --- a/tests/dims/utils.py +++ b/tests/dims/utils.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from pytensor import graph_replace
 from pytensor.graph import rewrite_graph
 from pytensor.printing import debugprint
+from pytensor.xtensor import as_xtensor

 from pymc import Model
 from pymc.testing import equal_computations_up_to_root
@@ -21,7 +23,7 @@ def assert_equivalent_random_graph(model: Model, reference_model: Model) -> bool:
     """Check if the random graph of a model with xtensor variables is equivalent."""
     lowered_model = rewrite_graph(
-        model.basic_RVs + model.deterministics + model.potentials,
+        [var.values for var in model.basic_RVs + model.deterministics + model.potentials],
         include=(
             "lower_xtensor",
             "inline_ofg_expansion_xtensor",
@@ -46,8 +48,13 @@ def assert_equivalent_random_graph(model: Model, reference_model: Model) -> bool:
 def assert_equivalent_logp_graph(model: Model, reference_model: Model) -> bool:
     """Check if the logp graph of a model with xtensor variables is equivalent."""
+    # Replace xtensor value variables by tensor value variables
+    replacements = {
+        var: as_xtensor(var.values.clone(name=var.name), dims=var.dims) for var in model.value_vars
+    }
+    model_logp = graph_replace(model.logp(), replacements)
     lowered_model_logp = rewrite_graph(
-        [model.logp()],
+        [model_logp],
         include=("lower_xtensor", "canonicalize", "local_remove_all_assert"),
     )
     reference_lowered_model_logp = rewrite_graph(

From d86dba48e10ad267c400a095946e0af0c36508a4 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 30 Jun 2025 12:51:01 +0200
Subject: [PATCH 06/11] Allow Dim version of simple SymbolicRandomVariables

---
 pymc/dims/distributions/scalar.py       | 19 ++++++++++++++++-
 pymc/distributions/distribution.py      | 28 ++++++++++++++++++++++---
 tests/dims/distributions/test_scalar.py | 17 +++++++++++++++
 3 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/pymc/dims/distributions/scalar.py b/pymc/dims/distributions/scalar.py
index 540f69cf8..80069c638 100644
--- a/pymc/dims/distributions/scalar.py
+++ b/pymc/dims/distributions/scalar.py
@@ -14,6 +14,8 @@
 import pytensor.xtensor as ptx
 import pytensor.xtensor.random as pxr

+from pytensor.xtensor import as_xtensor
+
 from pymc.dims.distributions.core import (
     DimDistribution,
     PositiveDimDistribution,
@@ -21,7 +23,7 @@
 )
 from pymc.distributions.continuous import Beta as RegularBeta
 from pymc.distributions.continuous import Gamma as RegularGamma
-from pymc.distributions.continuous import flat, halfflat
+from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat


 def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
@@ -89,6 +91,21 @@ def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
         return super().dist([nu, mu, sigma], **kwargs)


+class HalfStudentT(PositiveDimDistribution):
+    @classmethod
+    def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
+        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
+        return super().dist([nu, sigma], **kwargs)
+
+    @classmethod
+    def xrv_op(cls, nu, sigma, core_dims=None, extra_dims=None, rng=None):
+        nu = as_xtensor(nu)
+        sigma = as_xtensor(sigma)
+        core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
+        xop = pxr._as_xrv(core_rv)
+        return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
+
+
 class Cauchy(DimDistribution):
     xrv_op = pxr.cauchy
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index f8712c51e..213af82f4 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -32,7 +32,7 @@
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
 from pytensor.tensor.basic import as_tensor_variable
-from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.op import RandomVariable, RNGConsumerOp
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -207,7 +207,7 @@ def __get__(self, owner_self, owner_cls):
         return self.fget(owner_self if owner_self is not None else owner_cls)


-class SymbolicRandomVariable(MeasurableOp, OpFromGraph):
+class SymbolicRandomVariable(MeasurableOp, RNGConsumerOp, OpFromGraph):
     """Symbolic Random Variable.

     This is a subclass of `OpFromGraph` which is used to encapsulate the symbolic
@@ -294,7 +294,10 @@ def default_output(cls_or_self) -> int | None:
     @staticmethod
     def get_input_output_type_idxs(
         extended_signature: str | None,
-    ) -> tuple[tuple[tuple[int], int | None, tuple[int]], tuple[tuple[int], tuple[int]]]:
+    ) -> tuple[
+        tuple[tuple[int, ...], int | None, tuple[int, ...]],
+        tuple[tuple[int, ...], tuple[int, ...]],
+    ]:
         """Parse extended_signature and return indexes for *[rng], [size] and parameters as well as outputs."""
         if extended_signature is None:
             raise ValueError("extended_signature must be provided")
@@ -367,8 +370,27 @@ def __init__(
         kwargs.setdefault("inline", True)
         kwargs.setdefault("strict", True)
+        # Many RVs have a size argument, even when this is `None` and is therefore unused
+        kwargs.setdefault("on_unused_input", "ignore")
         super().__init__(*args, **kwargs)

+    def make_node(self, *inputs):
+        # If we try to build the RV with a different size type (vector -> None or None -> vector)
+        # we need to rebuild the Op with the new size type in the inner graph
+        if self.extended_signature is not None:
+            (rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs(
+                self.extended_signature
+            )
+            if size_arg_idx is not None and len(rng_arg_idxs) == 1:
+                new_size_type = normalize_size_param(inputs[size_arg_idx]).type
+                if not self.input_types[size_arg_idx].in_same_class(new_size_type):
+                    params = [inputs[idx] for idx in param_idxs]
+                    size = inputs[size_arg_idx]
+                    rng = inputs[rng_arg_idxs[0]]
+                    return self.rebuild_rv(*params, size=size, rng=rng).owner
+
+        return super().make_node(*inputs)
+
     def update(self, node: Apply) -> dict[Variable, Variable]:
         """Symbolic update expression for input random state variables.
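The `make_node` override above is what lets an existing symbolic RV be called again with a size of a different type (`None` vs. a concrete vector): the inner graph was built against the old size type, so the Op reconstructs itself through `rebuild_rv` instead of failing the input type check. A rough sketch of a call path that exercises this, using `change_dist_size` on a `SymbolicRandomVariable`-based distribution (illustrative only, not taken from this patch's tests):

    import pymc as pm
    from pymc.distributions.shape_utils import change_dist_size

    # ZeroSumNormal is implemented as a SymbolicRandomVariable
    zsn = pm.ZeroSumNormal.dist(n_zerosum_axes=1, shape=(3,))

    # Resizing calls the Op with a new size input; the inner graph is
    # rebuilt rather than raising a size-type mismatch.
    resized = change_dist_size(zsn, new_size=(5,))
    print(resized.eval().shape)  # expected (5, 3)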
diff --git a/tests/dims/distributions/test_scalar.py b/tests/dims/distributions/test_scalar.py index a487591c0..df47422de 100644 --- a/tests/dims/distributions/test_scalar.py +++ b/tests/dims/distributions/test_scalar.py @@ -22,6 +22,7 @@ HalfCauchy, HalfFlat, HalfNormal, + HalfStudentT, InverseGamma, Laplace, LogNormal, @@ -119,6 +120,22 @@ def test_studentt(): assert_equivalent_logp_graph(model, reference_model) +def test_halfstudentt(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + HalfStudentT("x", nu=1, dims="a") + HalfStudentT("y", nu=1, sigma=3, dims="a") + HalfStudentT("z", nu=1, lam=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfStudentT("x", nu=1, dims="a") + regular_distributions.HalfStudentT("y", nu=1, sigma=3, dims="a") + regular_distributions.HalfStudentT("z", nu=1, lam=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + def test_cauchy(): coords = {"a": range(3)} with Model(coords=coords) as model: From 5050fe692e0baf8694866c1f9cd9dede418344d3 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 24 Jun 2025 19:10:18 +0200 Subject: [PATCH 07/11] Implement Dim ZeroSumNormal --- pymc/dims/distributions/transforms.py | 42 +++++++++++++++ pymc/dims/distributions/vector.py | 68 ++++++++++++++++++++++--- pymc/distributions/multivariate.py | 11 ++-- tests/dims/distributions/test_vector.py | 20 +++++++- tests/dims/test_model.py | 53 +++++++++++++++++++ 5 files changed, 181 insertions(+), 13 deletions(-) diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py index 6805d1b5c..8f49d2a16 100644 --- a/pymc/dims/distributions/transforms.py +++ b/pymc/dims/distributions/transforms.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
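For intuition before the diff below: `ZeroSumTransform` maps between n-1 unconstrained entries and a zero-sum vector of length n along each core dim. A numpy sketch of the same extend/reduce pair (illustrative only, not part of the patch):

    import numpy as np

    def extend(x):  # mirrors ZeroSumTransform.extend_dim on the last axis
        n = x.shape[-1] + 1
        s = x.sum(-1, keepdims=True)
        norm = s / (np.sqrt(n) + n)
        fill = norm - s / np.sqrt(n)
        return np.concatenate([x, fill], axis=-1) - norm

    def reduce_(y):  # mirrors ZeroSumTransform.reduce_dim
        n = y.shape[-1]
        s = -y[..., -1:] * np.sqrt(n)
        return y[..., :-1] + s / (np.sqrt(n) + n)

    x = np.random.default_rng(0).normal(size=4)
    assert np.allclose(extend(x).sum(-1), 0)   # the zero-sum constraint holds
    assert np.allclose(reduce_(extend(x)), x)  # the round trip is exact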
+import pytensor.tensor as pt
 import pytensor.xtensor as ptx
 
 from pymc.logprob.transforms import Transform
@@ -51,3 +52,44 @@ def log_jac_det(self, value, *inputs):
 
 
 log_odds_transform = LogOddsTransform()
+
+
+class ZeroSumTransform(DimTransform):
+    name = "zerosum"
+
+    def __init__(self, dims: tuple[str, ...]):
+        self.dims = dims
+
+    @staticmethod
+    def extend_dim(array, dim):
+        n = (array.sizes[dim] + 1).astype("floatX")
+        sum_vals = array.sum(dim)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        fill_val = norm - sum_vals / pt.sqrt(n)
+
+        out = ptx.concat([array, fill_val], dim=dim)
+        return out - norm
+
+    @staticmethod
+    def reduce_dim(array, dim):
+        n = array.sizes[dim].astype("floatX")
+        last = array.isel({dim: -1})
+
+        sum_vals = -last * pt.sqrt(n)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        return array.isel({dim: slice(None, -1)}) + norm
+
+    def forward(self, value, *rv_inputs):
+        for dim in self.dims:
+            value = self.reduce_dim(value, dim=dim)
+        return value
+
+    def backward(self, value, *rv_inputs):
+        for dim in self.dims:
+            value = self.extend_dim(value, dim=dim)
+        return value
+
+    def log_jac_det(self, value, *rv_inputs):
+        # Use the following once broadcast_like is implemented:
+        # as_xtensor(0).broadcast_like(value, exclude=self.dims)
+        return value.sum(self.dims) * 0
diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py
index b11bd56f2..76638ba40 100644
--- a/pymc/dims/distributions/vector.py
+++ b/pymc/dims/distributions/vector.py
@@ -14,11 +14,14 @@
 import pytensor.xtensor as ptx
 import pytensor.xtensor.random as ptxr
 
-from pytensor.tensor.random.utils import normalize_size_param
+from pytensor.tensor import as_tensor
+from pytensor.xtensor import as_xtensor
 from pytensor.xtensor import random as pxr
 
 from pymc.dims.distributions.core import VectorDimDistribution
+from pymc.dims.distributions.transforms import ZeroSumTransform
 from pymc.distributions.multivariate import ZeroSumNormalRV
+from pymc.util import UNSET
 
 
 class Categorical(VectorDimDistribution):
@@ -94,9 +97,60 @@ def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs):
         return super().dist([mu, cov], core_dims=core_dims, **kwargs)
 
 
-class DimZeroSumNormalRV(ZeroSumNormalRV):
-    def make_node(self, rng, size, sigma, support_shape):
-        if not self.input_types[1].in_same_class(normalize_size_param(size).type):
-            # We need to rebuild the graph with new size type
-            return self.rv_op(sigma, support_shape, size=size, rng=rng).owner
-        return super().make_node(rng, size, sigma, support_shape)
+class ZeroSumNormal(VectorDimDistribution):
+    @classmethod
+    def __new__(
+        cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
+    ):
+        if core_dims is not None:
+            if isinstance(core_dims, str):
+                core_dims = (core_dims,)
+
+            # Create default_transform
+            if observed is None and default_transform is UNSET:
+                default_transform = ZeroSumTransform(dims=core_dims)
+
+        # If the user didn't specify dims, take them from core_dims
+        # We need them to be forwarded to dist in the `dim_lengths` argument
+        if dims is None and core_dims is not None:
+            dims = (..., *core_dims)
+
+        return super().__new__(
+            *args,
+            core_dims=core_dims,
+            dims=dims,
+            default_transform=default_transform,
+            observed=observed,
+            **kwargs,
+        )
+
+    @classmethod
+    def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs):
+        if isinstance(core_dims, str):
+            core_dims = (core_dims,)
+        if core_dims is None or len(core_dims) == 0:
+            raise ValueError("ZeroSumNormal requires at least one core dim")
+
+        
support_dims = as_xtensor( + as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), dims=("_",) + ) + sigma = cls._as_xtensor(sigma) + + return super().dist( + [sigma, support_dims], core_dims=core_dims, dim_lengths=dim_lengths, **kwargs + ) + + @classmethod + def xrv_op(self, sigma, support_dims, core_dims, extra_dims=None, rng=None): + sigma = as_xtensor(sigma) + support_dims = as_xtensor(support_dims, dims=("_",)) + support_shape = support_dims.values + core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op + xop = pxr._as_xrv( + core_rv, + core_inps_dims_map=[(), (0,)], + core_out_dims_map=tuple(range(1, len(core_dims) + 1)), + ) + # Dummy "_" core dim to absorb the support_shape vector + # If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed + return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f5a5506bb..5dd2509ef 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2664,6 +2664,7 @@ def logp(value, alpha, K): class ZeroSumNormalRV(SymbolicRandomVariable): """ZeroSumNormal random variable.""" + name = "ZeroSumNormal" _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") @classmethod @@ -2687,12 +2688,12 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None): zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True) support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)]) - extended_signature = f"[rng],(),(s),[size]->[rng],({support_str})" - return ZeroSumNormalRV( - inputs=[rng, sigma, support_shape, size], + extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})" + return cls( + inputs=[rng, size, sigma, support_shape], outputs=[next_rng, zerosum_rv], extended_signature=extended_signature, - )(rng, sigma, support_shape, size) + )(rng, size, sigma, support_shape) class ZeroSumNormal(Distribution): @@ -2828,7 +2829,7 @@ def zerosum_default_transform(op, rv): @_logprob.register(ZeroSumNormalRV) -def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs): +def zerosumnormal_logp(op, values, rng, size, sigma, support_shape, **kwargs): (value,) = values shape = value.shape n_zerosum_axes = op.ndim_supp diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py index 0f08505db..3a57453b4 100644 --- a/tests/dims/distributions/test_vector.py +++ b/tests/dims/distributions/test_vector.py @@ -19,7 +19,7 @@ import pymc.distributions as regular_distributions from pymc import Model -from pymc.dims import Categorical, MvNormal +from pymc.dims import Categorical, MvNormal, ZeroSumNormal from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph @@ -60,3 +60,21 @@ def test_mvnormal(): assert_equivalent_random_graph(model, reference_model) assert_equivalent_logp_graph(model, reference_model) + + +def test_zerosumnormal(): + coords = {"a": range(3), "b": range(2)} + with Model(coords=coords) as model: + ZeroSumNormal("x", core_dims=("b",), dims=("a", "b")) + ZeroSumNormal("y", sigma=3, core_dims=("b",), dims=("a", "b")) + ZeroSumNormal("z", core_dims=("a", "b"), dims=("a", "b")) + + with Model(coords=coords) as reference_model: + regular_distributions.ZeroSumNormal("x", dims=("a", "b")) + regular_distributions.ZeroSumNormal("y", sigma=3, n_zerosum_axes=1, dims=("a", "b")) + regular_distributions.ZeroSumNormal("z", n_zerosum_axes=2, dims=("a", 
"b")) + + assert_equivalent_random_graph(model, reference_model) + # Logp is correct, but we have join(..., -1) and join(..., 1), that don't get canonicalized to the same + # Should work once https://github.com/pymc-devs/pytensor/issues/1505 is fixed + # assert_equivalent_logp_graph(model, reference_model) diff --git a/tests/dims/test_model.py b/tests/dims/test_model.py index c67f629b9..68446c1a6 100644 --- a/tests/dims/test_model.py +++ b/tests/dims/test_model.py @@ -172,3 +172,56 @@ def test_complex_model(): tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False ) pm.sample_posterior_predictive(idata, progressbar=False) + + +def test_zerosumnormal_model(): + coords = {"time": range(5), "item": range(3)} + + with pm.Model(coords=coords) as model: + zsn_item = pmd.ZeroSumNormal("zsn_item", core_dims="item", dims=("time", "item")) + zsn_time = pmd.ZeroSumNormal("zsn_time", core_dims="time", dims=("time", "item")) + zsn_item_time = pmd.ZeroSumNormal("zsn_item_time", core_dims=("item", "time")) + assert zsn_item.type.dims == ("time", "item") + assert zsn_time.type.dims == ("time", "item") + assert zsn_item_time.type.dims == ("item", "time") + + zsn_item_draw, zsn_time_draw, zsn_item_time_draw = pm.draw( + [zsn_item, zsn_time, zsn_item_time], random_seed=1 + ) + assert zsn_item_draw.shape == (5, 3) + np.testing.assert_allclose(zsn_item_draw.mean(-1), 0, atol=1e-13) + assert not np.allclose(zsn_item_draw.mean(0), 0, atol=1e-13) + + assert zsn_time_draw.shape == (5, 3) + np.testing.assert_allclose(zsn_time_draw.mean(0), 0, atol=1e-13) + assert not np.allclose(zsn_time_draw.mean(-1), 0, atol=1e-13) + + assert zsn_item_time_draw.shape == (3, 5) + np.testing.assert_allclose(zsn_item_time_draw.mean(), 0, atol=1e-13) + + with pm.Model(coords=coords) as ref_model: + # Check that the ZeroSumNormal can be used in a model + pm.ZeroSumNormal("zsn_item", dims=("time", "item")) + pm.ZeroSumNormal("zsn_time", dims=("item", "time")) + pm.ZeroSumNormal("zsn_item_time", n_zerosum_axes=2, dims=("item", "time")) + + # Check initial_point and logp + ip = model.initial_point() + ref_ip = ref_model.initial_point() + assert ip.keys() == ref_ip.keys() + for i, (ip_value, ref_ip_value) in enumerate(zip(ip.values(), ref_ip.values())): + if i == 1: + # zsn_time is actually transposed in the original model + ip_value = ip_value.T + np.testing.assert_allclose(ip_value, ref_ip_value) + + logp_fn = model.compile_logp() + ref_logp_fn = ref_model.compile_logp() + np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip)) + + # Test a new point + rng = np.random.default_rng(68) + new_ip = ip.copy() + for key in new_ip: + new_ip[key] += rng.uniform(size=new_ip[key].shape) + np.testing.assert_allclose(logp_fn(new_ip), ref_logp_fn(new_ip)) From 8de64d69378670e4db63643037f9e186e6ba5a53 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:54:06 +0200 Subject: [PATCH 08/11] Arviz don't fail hard on incompatible coordinate lengths --- pymc/backends/arviz.py | 32 +++++++++++++++++++++++++++++++- tests/backends/test_arviz.py | 21 +++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index f0f0eec96..9cc20ed07 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -49,10 +49,40 @@ _log = logging.getLogger(__name__) + +RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False + + # random variable object ... 
Var = Any


+def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
+    safe_coords = coords
+
+    if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
+        coords_lengths = {k: len(v) for k, v in coords.items()}
+        for var_name, var in vars_dict.items():
+            # Iterate in reversed because of chain/draw batch dimensions
+            for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)):
+                coord_length = coords_lengths.get(dim, None)
+                if (coord_length is not None) and (coord_length != dim_length):
+                    warnings.warn(
+                        f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
+                        "The original coordinates for this dim will not be included in the returned dataset for any of the variables. "
+                        "Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
+                        "To make this warning into an error, set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
+                        UserWarning,
+                    )
+                    if safe_coords is coords:
+                        safe_coords = coords.copy()
+                    safe_coords.pop(dim)
+                    coords_lengths.pop(dim)
+
+    # FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
+    return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
+
+
 def find_observations(model: "Model") -> dict[str, Var]:
     """If there are observations available, return them as a dictionary."""
     observations = {}
@@ -365,7 +395,7 @@ def priors_to_xarray(self):
             priors_dict[group] = (
                 None
                 if var_names is None
-                else dict_to_dataset(
+                else dict_to_dataset_drop_incompatible_coords(
                     {k: np.expand_dims(self.prior[k], 0) for k in var_names},
                     library=pymc,
                     coords=self.coords,
diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py
index 3c06288b3..ddf81e274 100644
--- a/tests/backends/test_arviz.py
+++ b/tests/backends/test_arviz.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
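A minimal usage sketch of the new escape hatch (mirrors the test added below; not part of the patch):

    import pymc as pm

    with pm.Model(coords={"a": [0, 1, 2]}) as m:
        x = pm.Normal("x", dims="a")
        # A length-2 deterministic labeled with the length-3 dim "a"
        pm.Deterministic("b", x[1:], dims=("a",))

    # Default behavior: warn and drop the incompatible coordinate
    idata = pm.sample_prior_predictive(draws=1, model=m)

    # Opt in to hard failures instead of the warning
    pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True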
+import re import warnings import numpy as np @@ -837,3 +838,23 @@ def test_dataset_to_point_list_str_key(self): ds[3] = xarray.DataArray([1, 2, 3]) with pytest.raises(ValueError, match="must be str"): dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + + +def test_incompatible_coordinate_lengths(): + with pm.Model(coords={"a": [0, 1, 2]}) as m: + x = pm.Normal("x", dims="a") + y = pm.Deterministic("b", x[1:], dims=("a",)) + + with pytest.warns( + UserWarning, + match=re.escape( + "Incompatible coordinate length of 3 for dimension 'a' of variable 'b'" + ), + ): + pm.sample_prior_predictive(draws=1) + + pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True + with pytest.raises(ValueError): + pm.sample_prior_predictive(draws=1) + + pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False From ba76921c8f6f0544e51cfef4fe3eafdf4df9c05e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:31:47 +0200 Subject: [PATCH 09/11] Propagate Op name to SymbolicRandomVariables --- pymc/distributions/distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 213af82f4..27d53c868 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -372,6 +372,8 @@ def __init__( kwargs.setdefault("strict", True) # Many RVS have a size argument, even when this is `None` and is therefore unused kwargs.setdefault("on_unused_input", "ignore") + if hasattr(self, "name"): + kwargs.setdefault("name", self.name) super().__init__(*args, **kwargs) def make_node(self, *inputs): From 713eddde3ed777c0381e6aafc4cad88a64d75b4a Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:55:10 +0200 Subject: [PATCH 10/11] Tweaks to model_graph to play nice with XTensorVariables * Use RV Op name when provided * More robust detection of observed data variables (after https://github.com/pymc-devs/pymc/pull/7656 arbitrary graphs are allowed) * Remove self loops explicitly (closes https://github.com/pymc-devs/pymc/issues/7722) --- pymc/model_graph.py | 50 ++++++++++++++++++++------------------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 5185bfbf1..1e3a955c6 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -22,14 +22,13 @@ from pytensor import function from pytensor.graph import Apply -from pytensor.graph.basic import ancestors, walk -from pytensor.scalar.basic import Cast -from pytensor.tensor.elemwise import Elemwise +from pytensor.graph.basic import Variable, ancestors, walk from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.shape import Shape from pytensor.tensor.variable import TensorVariable from pymc.model.core import modelcontext +from pymc.pytensorf import _cheap_eval_mode from pymc.util import VarName, get_default_varnames, get_var_name __all__ = ( @@ -77,7 +76,7 @@ def create_plate_label_with_dim_length( def fast_eval(var): - return function([], var, mode="FAST_COMPILE")() + return function([], var, mode=_cheap_eval_mode)() class NodeType(str, Enum): @@ -124,12 +123,14 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: } -def random_variable_symbol(var: TensorVariable) -> str: +def random_variable_symbol(var: Variable) -> str: """Get the symbol of the random variable.""" - symbol = var.owner.op.__class__.__name__ + op = var.owner.op - if symbol.endswith("RV"): - symbol = symbol[:-2] + if name := getattr(op, "name", None): + symbol = 
name[0].upper() + name[1:] + else: + symbol = op.__class__.__name__.removesuffix("RV") return symbol @@ -319,28 +320,21 @@ def make_compute_graph( input_map[var_name] = input_map[var_name].union(parent_name) if var in self.model.observed_RVs: - obs_node = self.model.rvs_to_values[var] - - # loop created so that the elif block can go through this again - # and remove any intermediate ops, notably dtype casting, to observations - while True: - obs_name = obs_node.name - if obs_name and obs_name != var_name: + # Make observed `Data` variables flow from the observed RV, and not the other way around + # (In the generative graph they usually inform shape of the observed RV) + # We have to iterate over the ancestors of the observed values because there can be + # deterministic operations in between the `Data` variable and the observed value. + obs_var = self.model.rvs_to_values[var] + for ancestor in ancestors([obs_var]): + if (obs_name := cast(VarName, ancestor.name)) in input_map: input_map[var_name] = input_map[var_name].difference({obs_name}) input_map[obs_name] = input_map[obs_name].union({var_name}) - break - elif ( - # for cases where observations are cast to a certain dtype - # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795 - obs_node.owner - and isinstance(obs_node.owner.op, Elemwise) - and isinstance(obs_node.owner.op.scalar_op, Cast) - ): - # we can retrieve the observation node by going up the graph - obs_node = obs_node.owner.inputs[0] - else: + # break assumes observed values can depend on only one `Data` variable break + # Remove self references + for var_name in input_map: + input_map[var_name] = input_map[var_name].difference({var_name}) return input_map def get_plates( @@ -360,13 +354,13 @@ def get_plates( plates = defaultdict(set) # TODO: Evaluate all RV shapes at once - # This should help find discrepencies, and + # This should help find discrepancies, and # avoids unnecessary function compiles for determining labels. dim_lengths: dict[str, int] = { dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items() } var_shapes: dict[str, tuple[int, ...]] = { - var_name: tuple(fast_eval(self.model[var_name].shape)) + var_name: tuple(map(int, fast_eval(self.model[var_name].shape))) for var_name in self.vars_to_plot(var_names) } From c0ffb50151e4446b5ba1b619b417cfdba6bd0085 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:54:40 +0200 Subject: [PATCH 11/11] Core notebook on dims module --- .../learn/core_notebooks/dims_module.ipynb | 1666 +++++++++++++++++ 1 file changed, 1666 insertions(+) create mode 100644 docs/source/learn/core_notebooks/dims_module.ipynb diff --git a/docs/source/learn/core_notebooks/dims_module.ipynb b/docs/source/learn/core_notebooks/dims_module.ipynb new file mode 100644 index 000000000..dbe44f0b7 --- /dev/null +++ b/docs/source/learn/core_notebooks/dims_module.ipynb @@ -0,0 +1,1666 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "17e37649edaa8d0d", + "metadata": {}, + "source": [ + "# PyMC dims module" + ] + }, + { + "cell_type": "markdown", + "id": "4a2002b7ab9e00eb", + "metadata": {}, + "source": [ + "## A short history of dims in PyMC\n", + "\n", + "PyMC introduced the ability to specify model variable `dims` in version 3.9 in June 2020 (5 years as of the time of writing). 
There was only a discreet mention, following [14 other new features](https://github.com/pymc-devs/pymc/blob/1d00f3eb81723523968f3610e81a0c42fd96326f/RELEASE-NOTES.md?plain=1#L236), of a feature that would over time become a foundation of the library.\n",
+    "\n",
+    "It allows users to more naturally specify the dimensions of model variables with strings and provides a \"seamless\" conversion to arviz :doc:`InferenceData ` objects, which have become the standard for storing and investigating results from probabilistic programming languages.\n",
+    "\n",
+    "However, the behavior of dims is rather limited. It can only be used to specify the shape of new random variables and label existing dimensions (e.g., in :func:`~pymc.Deterministic`). It has otherwise no effect on the computation, unlike operations done with :class:`~arviz.InferenceData` variables, which are based on :mod:`xarray` and where dims inform array selection, alignment and broadcasting behavior.\n",
+    "\n",
+    "In contrast, with PyMC models, users have to write computations that follow numpy semantics, which usually means a plethora of transpositions, reshapes, new axis (`None`) and numerical axis arguments sprinkled everywhere. It can be hard to get these right, and after all is said and done, somewhat hard to make sense of the written model.\n",
+    "\n",
+    "## Expanding the role of dims\n",
+    "\n",
+    "PyMC introduces an experimental :mod:`pymc.dims` module that allows users to define data, distributions and math operations that respect dim semantics, as close as possible to xarray. Let us take a look. We start with a model written in old-fashioned PyMC syntax:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "921cb1a07f11b51f",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:25.197894Z",
+     "start_time": "2025-06-30T15:53:20.093896Z"
+    }
+   },
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "import pymc as pm\n",
+    "\n",
+    "seed = sum(map(ord, \"dims module\"))\n",
+    "rng = np.random.default_rng(seed)"
+   ],
+   "outputs": [],
+   "execution_count": 1
+  },
+  {
+   "cell_type": "code",
+   "id": "69ecff83e6b09dc3",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:25.522571Z",
+     "start_time": "2025-06-30T15:53:25.393579Z"
+    }
+   },
+   "source": [
+    "# Very realistic looking data!\n",
+    "observed_response_np = np.ones((5, 20), dtype=int)\n",
+    "coords = {\n",
+    "    \"participant\": range(5),\n",
+    "    \"trial\": range(20),\n",
+    "    \"item\": range(3),\n",
+    "}\n",
+    "with pm.Model(coords=coords) as model:\n",
+    "    observed_response = pm.Data(\n",
+    "        \"observed_response\", observed_response_np, dims=(\"participant\", \"trial\")\n",
+    "    )\n",
+    "    # Use ZeroSumNormal to avoid identifiability issues\n",
+    "    participant_preference = pm.ZeroSumNormal(\n",
+    "        \"participant_preference\", n_zerosum_axes=1, dims=(\"participant\", \"item\")\n",
+    "    )\n",
+    "\n",
+    "    # Shared time effects across all participants\n",
+    "    time_effects = pm.Normal(\"time_effects\", dims=(\"trial\", \"item\"))\n",
+    "\n",
+    "    trial_preference = pm.Deterministic(\n",
+    "        \"trial_preference\",\n",
+    "        participant_preference[:, None, :] + time_effects[None, :, :],\n",
+    "        dims=(\"participant\", \"trial\", \"item\"),\n",
+    "    )\n",
+    "\n",
+    "    response = pm.Categorical(\n",
+    "        \"response\",\n",
+    "        p=pm.math.softmax(trial_preference, axis=-1),\n",
+    "        observed=observed_response,\n",
+    "        dims=(\"participant\", \"trial\"),\n",
+    "    )"
+   ],
+   "outputs": [],
+   "execution_count": 2
+  },
+  {
+   "cell_type": "code",
+   "id": "2efa25b6d8713c0",
+   
"metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:29.819511Z", + "start_time": "2025-06-30T15:53:25.547610Z" + } + }, + "source": [ + "model.to_graphviz()" + ], + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\nclusterparticipant (5) x trial (20)\n\nparticipant (5) x trial (20)\n\n\nclusterparticipant (5) x item (3)\n\nparticipant (5) x item (3)\n\n\nclustertrial (20) x item (3)\n\ntrial (20) x item (3)\n\n\nclusterparticipant (5) x trial (20) x item (3)\n\nparticipant (5) x trial (20) x item (3)\n\n\n\nresponse\n\nresponse\n~\nCategorical\n\n\n\nobserved_response\n\nobserved_response\n~\nData\n\n\n\nresponse->observed_response\n\n\n\n\n\nparticipant_preference\n\nparticipant_preference\n~\nZeroSumNormal\n\n\n\ntrial_preference\n\ntrial_preference\n~\nDeterministic\n\n\n\nparticipant_preference->trial_preference\n\n\n\n\n\ntime_effects\n\ntime_effects\n~\nNormal\n\n\n\ntime_effects->trial_preference\n\n\n\n\n\ntrial_preference->response\n\n\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "id": "cd191bed68527806", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:07.115779Z", + "start_time": "2025-06-26T21:31:05.140898Z" + } + }, + "source": "And now let's write the equivalent model using the :mod:`pymc.dims` module." + }, + { + "cell_type": "code", + "id": "94964484499b163c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:29.908171Z", + "start_time": "2025-06-30T15:53:29.845474Z" + } + }, + "source": [ + "import pymc.dims as pmd" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ricardo/Documents/pymc/pymc/dims/__init__.py:66: UserWarning: The `pymc.dims` module is experimental and may contain critical bugs (p=0.565).\n", + "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n", + "Disclaimer: This an experimental API and may change at any time.\n", + " __init__()\n" + ] + } + ], + "execution_count": 4 + }, + { + "cell_type": "code", + "id": "c020e450cc165e46", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:29.997156Z", + "start_time": "2025-06-30T15:53:29.935025Z" + } + }, + "source": [ + "with pm.Model(coords=coords) as dmodel:\n", + " observed_response = pmd.Data(\n", + " \"observed_response\", observed_response_np, dims=(\"participant\", \"trial\")\n", + " )\n", + " participant_preference = pmd.ZeroSumNormal(\n", + " \"participant_preference\", core_dims=\"item\", dims=(\"participant\", \"item\")\n", + " )\n", + "\n", + " # Shared time effects across all participants\n", + " time_effects = pmd.Normal(\"time_effects\", dims=(\"item\", \"trial\"))\n", + "\n", + " trial_preference = pmd.Deterministic(\n", + " \"trial_preference\",\n", + " participant_preference + time_effects,\n", + " )\n", + "\n", + " response = pmd.Categorical(\n", + " \"response\",\n", + " p=pmd.math.softmax(trial_preference, dim=\"item\"),\n", + " core_dims=\"item\",\n", + " observed=observed_response,\n", + " )" + ], + "outputs": [], + "execution_count": 5 + }, + { + "cell_type": "markdown", + "id": "be93cdc3ae56689a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:07.214221Z", + "start_time": "2025-06-26T21:31:07.191541Z" + } + }, + "source": [ + "Note we still use the same :class:`~pymc.Model` constructor, but everything else was now defined with an equivalent function or class defined in the 
:mod:`pymc.dims` module.\n",
+    "\n",
+    "There are some notable differences:\n",
+    "1. ZeroSumNormal takes a `core_dims` argument instead of `n_zerosum_axes`. This tells PyMC which of the `dims` that define the distribution are constrained to be zero-summed. All distributions that take non-scalar parameters now require a `core_dims` argument, whereas before the user was expected to right-align them (see more in :doc:`dimensionality`); this is no longer necessary. You shouldn't have to worry about the order of the dimensions in your model, just their meaning!\n",
+    "2. The `trial_preference` computation is mathematically equivalent to what we had before, without us having to align dimensions for broadcasting following numpy semantics.\n",
+    "3. The `softmax` operation is defined based on the `dim` argument, not the positional axis. Note: why is it `dim` and not `core_dims`? We try to stay as close as possible to xarray syntax, which uses `dim` throughout (even though xarray doesn't have a softmax operation). However, reusing `dim` for distributions would probably be confusing, since they already have the `dims` argument.\n",
+    "4. The `Categorical` observed variable, like `ZeroSumNormal`, requires a `core_dims` argument to know which dimension corresponds to the probability vector. Before, users were required to place this dimension explicitly on the rightmost axis; this is not necessary anymore.\n",
+    "5. Even though dims were not specified for either `trial_preference` or `response`, PyMC automatically infers them. You can check that the graphviz representation is identical:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "43e1a65de0af06b5",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:30.408038Z",
+     "start_time": "2025-06-30T15:53:30.034734Z"
+    }
+   },
+   "source": [
+    "dmodel.to_graphviz()"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": "[graphviz rendering of the dims model graph]",
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 6
+  },
+  {
+   "cell_type": "markdown",
+   "id": "41d93de2d86c61fd",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:07.373955Z",
+     "start_time": "2025-06-26T21:31:07.243130Z"
+    }
+   },
+   "source": [
+    "To convince ourselves the models are equivalent, we print the logp of each variable evaluated at the initial point."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "11073d9b67a72f72",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:34.534342Z",
+     "start_time": "2025-06-30T15:53:30.457494Z"
+    }
+   },
+   "source": [
+    "print(model.point_logps())\n",
+    "print(dmodel.point_logps())"
+   ],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'participant_preference': np.float64(-9.19), 'time_effects': np.float64(-55.14), 'response': np.float64(-109.86)}\n",
+      "{'participant_preference': np.float64(-9.19), 'time_effects': np.float64(-55.14), 'response': np.float64(-109.86)}\n"
+     ]
+    }
+   ],
+   "execution_count": 7
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ee981114897c7ee0",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:08.349480Z",
+     "start_time": "2025-06-26T21:31:07.387385Z"
+    }
+   },
+   "source": [
+    "### A brief look under the hood"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e360b5f5e9e8ca1e",
+   "metadata": {},
+   "source": "The :mod:`pymc.dims` module functionality is built on top of the (also experimental) :mod:`pytensor.xtensor` module in PyTensor, which is the :mod:`xarray` analogue of the :mod:`pytensor.tensor` module you may be familiar with (see :doc:`pymc_and_pytensor`). Whereas regular distributions and math operations return :class:`pytensor.tensor.TensorVariable` objects, the :mod:`pymc.dims` module returns :class:`pytensor.xtensor.type.XTensorVariable` objects. These are very similar to TensorVariables, but they have a `dims` attribute that contains the names of the dimensions of the variable, and which modifies the behavior of operations on them."
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": "We create a regular Normal random variable with 3 entries, and perform an outer addition on them using numpy syntax.",
+   "id": "5a1fa8ecb7309782"
+  },
+  {
+   "cell_type": "code",
+   "id": "9bece958e432c369",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:34.611874Z",
+     "start_time": "2025-06-30T15:53:34.598883Z"
+    }
+   },
+   "source": [
+    "regular_normal = pm.Normal.dist(mu=pm.math.as_tensor([0, 1, 2]), sigma=1, shape=(3,))\n",
+    "regular_normal.type"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorType(float64, shape=(3,))"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 8
+  },
+  {
+   "cell_type": "code",
+   "id": "5c0aa77e31170c54",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:34.700430Z",
+     "start_time": "2025-06-30T15:53:34.693544Z"
+    }
+   },
+   "source": [
+    "type(regular_normal)"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "pytensor.tensor.variable.TensorVariable"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 9
+  },
+  {
+   "cell_type": "code",
+   "id": "d77168a8c6e89de9",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:34.805130Z",
+     "start_time": "2025-06-30T15:53:34.792870Z"
+    }
+   },
+   "source": [
+    "outer_addition = regular_normal[:, None] + regular_normal[None, :]\n",
+    "outer_addition.type"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorType(float64, shape=(3, 3))"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 10
+  },
+  {
+   "cell_type": "code",
+   "id": "ac95fe88c1877fbe",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:34.930130Z",
+     "start_time": "2025-06-30T15:53:34.887799Z"
+    }
+   },
+   "source": [
+    
"pm.draw(outer_addition, random_seed=rng)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.61284312, 1.68384684, 1.72225487],\n", + " [1.68384684, 2.75485056, 2.79325859],\n", + " [1.72225487, 2.79325859, 2.83166662]])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "The same operation with a dimmed Normal variable, requires the use of `rename` so that the dimensions broadcast orthogonally.", + "id": "b7d78e2f3984adbc" + }, + { + "cell_type": "code", + "id": "4d9e8417789d4c18", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.055530Z", + "start_time": "2025-06-30T15:53:35.045241Z" + } + }, + "source": [ + "dims_normal = pmd.Normal.dist(mu=pmd.math.as_xtensor([0, 1, 2], dims=(\"a\",)), sigma=1)\n", + "dims_normal.type" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "XTensorType(float64, shape=(3,), dims=('a',))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 12 + }, + { + "cell_type": "code", + "id": "1d0c30011c910370", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.264315Z", + "start_time": "2025-06-30T15:53:35.256071Z" + } + }, + "source": [ + "type(dims_normal)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "pytensor.xtensor.type.XTensorVariable" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 13 + }, + { + "cell_type": "code", + "id": "923e75f235dcf219", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.472551Z", + "start_time": "2025-06-30T15:53:35.465427Z" + } + }, + "source": [ + "outer_addition = dims_normal + dims_normal.rename({\"a\": \"b\"})\n", + "outer_addition.type" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "XTensorType(float64, shape=(3, 3), dims=('a', 'b'))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 14 + }, + { + "cell_type": "code", + "id": "cd9e2a2634a0e898", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.758679Z", + "start_time": "2025-06-30T15:53:35.649256Z" + } + }, + "source": [ + "pm.draw(outer_addition, random_seed=rng)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 3.76355516, 0.31059132, 6.5420105 ],\n", + " [ 0.31059132, -3.14237253, 3.08904666],\n", + " [ 6.5420105 , 3.08904666, 9.32046584]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 15 + }, + { + "cell_type": "markdown", + "id": "e12c7cda782cac69", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365035095Z", + "start_time": "2025-06-26T11:42:19.021166Z" + } + }, + "source": [ + "### Redundant (or implicit) dims" + ] + }, + { + "cell_type": "markdown", + "id": "4ac4877e24859dfe", + "metadata": {}, + "source": [ + "When defining deterministic operations or creating variables whose dimension are all implied by the parameters, there's no need to specify the `dims` argument, as PyMC will automatically know them.\n", + "\n", + "Despite this, we anticipate users will still want to do it. They can work as a sanity check that the dimensions of the variables are what on expects them to be, or simply as type hints for someone reading the model.\n", + "\n", + "PyMC allows specifying dimensions in these cases. 
To reduce confusion, the output will always be transposed to be aligned with the user-specified dims."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "cd6031d682b88e51",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:35.989010Z",
+     "start_time": "2025-06-30T15:53:35.965369Z"
+    }
+   },
+   "source": [
+    "with pm.Model(coords={\"a\": range(2), \"b\": range(5)}) as example:\n",
+    "    x = pmd.Normal(\"x\", dims=(\"a\", \"b\"))\n",
+    "    det_implicit_dims = pmd.Deterministic(\"det1\", x + 1)\n",
+    "    det_explicit_dims = pmd.Deterministic(\"det2\", x + 1, dims=(\"a\", \"b\"))\n",
+    "    det_transposed_dims = pmd.Deterministic(\"y\", x + 1, dims=(\"b\", \"a\"))\n",
+    "\n",
+    "print(f\"{det_implicit_dims.dims=}\")\n",
+    "print(f\"{det_explicit_dims.dims=}\")\n",
+    "print(f\"{det_transposed_dims.dims=}\")"
+   ],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "det_implicit_dims.dims=('a', 'b')\n",
+      "det_explicit_dims.dims=('a', 'b')\n",
+      "det_transposed_dims.dims=('b', 'a')\n"
+     ]
+    }
+   ],
+   "execution_count": 16
+  },
+  {
+   "cell_type": "markdown",
+   "id": "db657b20b447c56a",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:08.365174897Z",
+     "start_time": "2025-06-26T11:42:19.389561Z"
+    }
+   },
+   "source": [
+    "This happens with `Deterministic`, `Potential` and every distribution in the `dims` module.\n",
+    "Every time you specify dims, you will get back a variable with the dims in that order."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e5f73575a51d316a",
+   "metadata": {},
+   "source": [
+    "Furthermore (and unlike regular PyMC objects), it is now valid to use ellipsis in the `dims` argument, which, like in xarray's transpose, means \"all the other dimensions\" should stay in the same order."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "a1e3597de136004f",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:36.397445Z",
+     "start_time": "2025-06-30T15:53:36.371682Z"
+    }
+   },
+   "source": [
+    "with pm.Model(coords={\"a\": range(2), \"b\": range(5)}) as example:\n",
+    "    x = pmd.Normal(\"x\", dims=(\"a\", \"b\"))\n",
+    "    det_ellipsis1 = pmd.Deterministic(\"det1\", x + 1, dims=(...,))\n",
+    "    det_ellipsis2 = pmd.Deterministic(\"det2\", x + 1, dims=(..., \"a\"))\n",
+    "    det_ellipsis3 = pmd.Deterministic(\"det3\", x + 1, dims=(\"b\", ...))\n",
+    "\n",
+    "print(f\"{det_ellipsis1.dims=}\")\n",
+    "print(f\"{det_ellipsis2.dims=}\")\n",
+    "print(f\"{det_ellipsis3.dims=}\")"
+   ],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "det_ellipsis1.dims=('a', 'b')\n",
+      "det_ellipsis2.dims=('b', 'a')\n",
+      "det_ellipsis3.dims=('b', 'a')\n"
+     ]
+    }
+   ],
+   "execution_count": 17
+  },
+  {
+   "cell_type": "markdown",
+   "id": "318c563fb2eb106b",
+   "metadata": {},
+   "source": [
+    "### What functionality is supported?"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "88acbfae651fe294",
+   "metadata": {},
+   "source": [
+    "The documentation is still a work in progress, and there is no complete list of supported distributions and operations just yet.\n",
+    "\n",
+    "#### Model constructors\n",
+    "The following PyMC model constructors are available in the `dims` module.\n",
+    "\n",
+    " * :func:`~pymc.dims.Data`\n",
+    " * :func:`~pymc.dims.Deterministic`\n",
+    " * :func:`~pymc.dims.Potential`\n",
+    "\n",
+    "They all return :class:`pytensor.xtensor.type.XTensorVariable` objects, and either infer `dims` from the input or require the user to specify them explicitly. 
If they can be inferred, it is possible to transpose and use ellipsis in the `dims` argument, as described above.\n",
+    "\n",
+    "#### Distributions\n",
+    "We want to offer all the existing distributions and parametrizations under the :mod:`pymc.dims` module, with the following expected API differences:\n",
+    " * all vector arguments (and observed values) must have known dims; an error is raised otherwise.\n",
+    " * distributions with non-scalar inputs will require a `core_dims` argument.\n",
+    " * the meaning of the `core_dims` argument will be denoted in the docstrings of each distribution. For example, for the MvNormal, the `core_dims` are the two dimensions of the covariance matrix, one (and only one) of which must also be present in the mean parameter. The shared core dim is the one that persists in the output. Sometimes the order of `core_dims` will be important!\n",
+    " * `dims` accepts ellipsis, and variables are transposed to match the user-specified `dims` argument.\n",
+    " * `shape` and `size` cannot be provided.\n",
+    " * the :meth:`pymc.dims.distributions.core.DimDistribution.dist` method accepts a `dim_lengths` argument, of the form `{dim_name: dim_length}`.\n",
+    " * only transforms defined in :mod:`pymc.dims.transforms` can be used with distributions from the module.\n",
+    "\n",
+    "#### Operations on variables\n",
+    "\n",
+    "Calling a PyMC distribution from the :mod:`pymc.dims` module returns an :class:`pytensor.xtensor.type.XTensorVariable`.\n",
+    "\n",
+    "The expectation is that every :class:`xarray.DataArray` method should have an equivalent version for XTensorVariables. So if you can do `x.diff(dim=\"a\")` in xarray, you should be able to do `x.diff(dim=\"a\")` with XTensorVariables as well.\n",
+    "\n",
+    "In addition, many numerical operations are available in the :mod:`pymc.dims.math` module, which provides a superset of the `ufuncs` found in xarray (like `exp`). It also includes submodules such as `linalg` that provide counterparts to libraries like :mod:`xarray_einstats` (such as `linalg.solve`).\n",
+    "\n",
+    "Finally, functions that are available at the module level in xarray (like `concat`) are also available in the :mod:`pymc.dims` namespace.\n",
+    "\n",
+    "To facilitate adoption of these functions and methods, we try to follow the same API used by xarray and related packages. However, some methods or keyword arguments won't be supported explicitly (like `.isel`; more on that at the end), in which case an informative error or warning will be raised.\n",
+    "\n",
+    "If you find an API difference or some missing functionality, and no reason is provided, please [open an issue](https://github.com/pymc-devs/pymc/issues) to let us know (after checking nobody has done it already).\n",
+    "\n",
+    "In the meantime, the next section provides some hints on how to make use of pre-existing functionality in PyMC/PyTensor."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "98a94f5939872e75",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:08.365484979Z",
+     "start_time": "2025-06-26T11:42:19.577389Z"
+    }
+   },
+   "source": [
+    "### Combining dims module with the old API"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7a253e168df7d982",
+   "metadata": {},
+   "source": [
+    "Because the `dims` module is more recent, it does not offer all the functionality of the old API.\n",
+    "\n",
+    "You can always combine the two APIs by converting the variables explicitly. 
To obtain a regular non-dimmed variable from a dimmed variable, you can use :attr:`pytensor.xtensor.type.XTensorVariable.values` (like in xarray); to go the other way, attach dims to a regular variable with the more verbose :func:`pymc.dims.as_xtensor`.\n",
+    "\n",
+    "Otherwise, if you try to pass an XTensorVariable to a function or distribution that does not support it, you will usually see an error like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "8fd12e65aa9739c",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:36.555888Z",
+     "start_time": "2025-06-30T15:53:36.548186Z"
+    }
+   },
+   "source": [
+    "mu = pmd.math.as_xtensor([0, 1, 2], dims=(\"a\",))\n",
+    "try:\n",
+    "    pm.Normal.dist(mu=mu)\n",
+    "except TypeError as e:\n",
+    "    print(f\"{e.__class__.__name__}: {e}\")"
+   ],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TypeError: To avoid subtle bugs, PyTensor forbids automatic conversion of XTensorVariable to TensorVariable.\n",
+      "You can convert explicitly using `x.values` or pass `allow_xtensor_conversion=True`.\n"
+     ]
+    }
+   ],
+   "execution_count": 18
+  },
+  {
+   "cell_type": "code",
+   "id": "335b094a6cdd0bd8",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:36.767188Z",
+     "start_time": "2025-06-30T15:53:36.745075Z"
+    }
+   },
+   "source": [
+    "pm.Normal.dist(mu=x.values).type"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorType(float64, shape=(None, None))"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 19
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": "The order of the dimensions follows that specified in the :attr:`pytensor.xtensor.type.XTensorVariable.dims` property. To be sure this matches your expectation, you can use a :meth:`pytensor.xtensor.type.XTensorVariable.transpose` operation to reorder the dimensions before converting to a regular variable.",
+   "id": "eca88323e2529a57"
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e256f247d3bd9e1",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:08.365732804Z",
+     "start_time": "2025-06-26T11:42:19.755931Z"
+    }
+   },
+   "source": [
+    "Conversely, if you try to pass a regular variable to a function or distribution that expects an XTensorVariable, you will see an error like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "380f7161ca2d22e4",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:36.879320Z",
+     "start_time": "2025-06-30T15:53:36.868282Z"
+    }
+   },
+   "source": [
+    "mu = pm.math.as_tensor([0, 1, 2], name=\"mu_x\")\n",
+    "try:\n",
+    "    x = pmd.Normal.dist(mu=mu)\n",
+    "except Exception as e:\n",
+    "    print(f\"{e.__class__.__name__}: {e}\")"
+   ],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ValueError: Variable mu_x{[0 1 2]} must have dims associated with it.\n",
+      "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n",
+      "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly.\n"
+     ]
+    }
+   ],
+   "execution_count": 20
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e507c0a8e76447fd",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:08.365847399Z",
+     "start_time": "2025-06-26T11:42:19.826864Z"
+    }
+   },
+   "source": [
+    "Which you can avoid by explicitly converting the variable to a dimmed variable:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "b50363261fe81df6",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:37.226477Z",
+     "start_time": 
"2025-06-30T15:53:37.204756Z" + } + }, + "source": [ + "pmd.Normal.dist(mu=pmd.as_xtensor(mu, dims=(\"a\",))).type" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "XTensorType(float64, shape=(3,), dims=('a',))" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 21 + }, + { + "cell_type": "markdown", + "id": "e7b13e9b5cd20a63", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365953097Z", + "start_time": "2025-06-26T11:42:19.885598Z" + } + }, + "source": [ + "#### Example\n", + "\n", + "To put this to practice, let us write a model that uses the :class:`~pymc.LKJCholeskyCov` distribution, which at the time of writing is not yet available in the :mod:`pymc.dims` module." + ] + }, + { + "cell_type": "code", + "id": "b4ae78e1b95f5198", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:38.236088Z", + "start_time": "2025-06-30T15:53:37.672791Z" + } + }, + "source": [ + "with pm.Model(coords={\"core1\": range(3), \"core2\": range(3), \"batch\": range(5)}) as mixed_api_model:\n", + " chol, _, _ = pm.LKJCholeskyCov(\n", + " \"chol\",\n", + " eta=1,\n", + " n=3,\n", + " sd_dist=pm.Exponential.dist(1),\n", + " )\n", + " chol_xr = pmd.as_xtensor(chol, dims=(\"core1\", \"core2\"))\n", + "\n", + " mu = pmd.Normal(\"mu\", dims=(\"batch\", \"core1\"))\n", + " y = pmd.MvNormal(\n", + " \"y\",\n", + " mu,\n", + " chol=chol_xr,\n", + " core_dims=(\"core1\", \"core2\"),\n", + " )\n", + "\n", + "print(f\"{chol_xr.dims=}\")\n", + "print(f\"{mu.dims=}\")\n", + "print(f\"{y.dims=}\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "chol_xr.dims=('core1', 'core2')\n", + "mu.dims=('batch', 'core1')\n", + "y.dims=('batch', 'core1')\n" + ] + } + ], + "execution_count": 22 + }, + { + "cell_type": "markdown", + "id": "3f0dd0fea61f8862", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.366068103Z", + "start_time": "2025-06-26T11:42:19.946909Z" + } + }, + "source": [ + "Note that we had to pass a \"regular\" Exponential distribution to the :class:`~pymc.LKJCholeskyCov` constructor. In general all distribution \"factories\" which are parametrized by unnamed distributions created with the :met:`pymc.distributions.distribution.Distribution.dist` method, won't work with variables created with the :mod:`pymc.dims` module.\n", + "\n", + "Overtime we hope to implement such functionality directly in the :mod:`pymc.dims` module, but for now you have to be aware of this limitation." + ] + }, + { + "cell_type": "markdown", + "id": "6464ae49ff629660", + "metadata": {}, + "source": [ + "## Case study: a splines model comes ashore" + ] + }, + { + "cell_type": "markdown", + "id": "5150d0fa3017d9b9", + "metadata": {}, + "source": [ + "### A model begging for vectorization" + ] + }, + { + "cell_type": "markdown", + "id": "c7c622bbde1c675", + "metadata": {}, + "source": [ + "The model below was presented by a user in a [bug report](). There may have been other reasons for approaching the model in this way, and it may have deviated from the user application for the purposes of providing a reproducible example for the bug report.\n", + "\n", + "With the disclaimer out of the way, we can say that the model is written in a way that is highly suboptimal. Specifically it misses (or actively breaks) many opportunities for vectorization. 
Seasoned Python programmers will know that Python loops are SLOW, and that tools like numpy provide a way to escape from this handicap.\n",
+    "\n",
+    "PyMC code is not exactly numpy: for starters, it uses a lazy symbolic computation library (PyTensor) that generates compiled code on demand. But it very much likes to be given numpy-like code. To begin with, such graphs are much smaller and therefore easier to reason about (in fact, the original bug could only be triggered for graphs with more than 500 nodes). Secondly, numpy-like graphs naturally translate to vectorized CPU and GPU code, which is what you want at the end of the day."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "726f696e6bf17d44",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:39.748768Z",
+     "start_time": "2025-06-30T15:53:39.740157Z"
+    }
+   },
+   "source": [
+    "# Simulated data of some spline\n",
+    "N = 500\n",
+    "x_np = np.linspace(0, 10, N)\n",
+    "y_obs_np = np.piecewise(\n",
+    "    x_np,\n",
+    "    [x_np <= 3, (x_np > 3) & (x_np <= 7), x_np > 7],\n",
+    "    [lambda x: 0.5 * x, lambda x: 1.5 + 0.2 * (x - 3), lambda x: 2.3 - 0.1 * (x - 7)],\n",
+    ")\n",
+    "y_obs_np += rng.normal(0, 0.2, size=N)  # Add noise\n",
+    "\n",
+    "# Artificial groups\n",
+    "groups = [0, 1, 2]\n",
+    "group_idx_np = rng.choice(groups, size=N)\n",
+    "\n",
+    "n_knots = 50\n",
+    "knots_np = np.linspace(0, 10, num=n_knots)"
+   ],
+   "outputs": [],
+   "execution_count": 23
+  },
+  {
+   "cell_type": "code",
+   "id": "2e8519125d17f0c8",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:40.367851Z",
+     "start_time": "2025-06-30T15:53:39.854962Z"
+    }
+   },
+   "source": [
+    "with pm.Model() as non_vectorized_splines_model:\n",
+    "    sigma_beta0 = pm.HalfNormal(\"sigma_beta0\", sigma=10)\n",
+    "    sigma = pm.HalfCauchy(\"sigma\", beta=1)\n",
+    "\n",
+    "    # Create likelihood per group\n",
+    "    for gr in groups:\n",
+    "        idx = group_idx_np == gr\n",
+    "\n",
+    "        beta0 = pm.HalfNormal(f\"beta0_{gr}\", sigma=sigma_beta0)\n",
+    "        z = pm.Normal(f\"z_{gr}\", mu=0, sigma=2, shape=n_knots)\n",
+    "\n",
+    "        delta_factors = pm.math.softmax(z)\n",
+    "        slope_factors = 1 - pm.math.cumsum(delta_factors[:-1])\n",
+    "        spline_slopes = pm.math.stack(\n",
+    "            [beta0] + [beta0 * slope_factors[i] for i in range(n_knots - 1)]\n",
+    "        )\n",
+    "        beta = pm.Deterministic(\n",
+    "            f\"beta_{gr}\",\n",
+    "            pm.math.concatenate(([beta0], pm.math.diff(spline_slopes))),\n",
+    "        )\n",
+    "\n",
+    "        hinge_terms = [pm.math.maximum(0, x_np[idx] - knot) for knot in knots_np]\n",
+    "        X = pm.math.stack([hinge_terms[i] for i in range(n_knots)], axis=1)\n",
+    "\n",
+    "        mu = pm.math.dot(X, beta)\n",
+    "\n",
+    "        pm.Normal(f\"y_{gr}\", mu=mu, sigma=sigma, observed=y_obs_np[idx])"
+   ],
+   "outputs": [],
+   "execution_count": 24
+  },
+  {
+   "cell_type": "code",
+   "id": "c56dcb51a07a5b54",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:40.700094Z",
+     "start_time": "2025-06-30T15:53:40.463806Z"
+    }
+   },
+   "source": [
+    "non_vectorized_splines_model.to_graphviz()"
+   ],
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": 
"\n\n\n\n\n\n\n\ncluster50\n\n50\n\n\ncluster176\n\n176\n\n\ncluster168\n\n168\n\n\ncluster156\n\n156\n\n\n\nbeta0_0\n\nbeta0_0\n~\nHalfnormal\n\n\n\nbeta_0\n\nbeta_0\n~\nDeterministic\n\n\n\nbeta0_0->beta_0\n\n\n\n\n\nsigma_beta0\n\nsigma_beta0\n~\nHalfnormal\n\n\n\nsigma_beta0->beta0_0\n\n\n\n\n\nbeta0_2\n\nbeta0_2\n~\nHalfnormal\n\n\n\nsigma_beta0->beta0_2\n\n\n\n\n\nbeta0_1\n\nbeta0_1\n~\nHalfnormal\n\n\n\nsigma_beta0->beta0_1\n\n\n\n\n\nbeta_2\n\nbeta_2\n~\nDeterministic\n\n\n\nbeta0_2->beta_2\n\n\n\n\n\nbeta_1\n\nbeta_1\n~\nDeterministic\n\n\n\nbeta0_1->beta_1\n\n\n\n\n\nsigma\n\nsigma\n~\nHalfcauchy\n\n\n\ny_0\n\ny_0\n~\nNormal\n\n\n\nsigma->y_0\n\n\n\n\n\ny_1\n\ny_1\n~\nNormal\n\n\n\nsigma->y_1\n\n\n\n\n\ny_2\n\ny_2\n~\nNormal\n\n\n\nsigma->y_2\n\n\n\n\n\nz_1\n\nz_1\n~\nNormal\n\n\n\nz_1->beta_1\n\n\n\n\n\nz_2\n\nz_2\n~\nNormal\n\n\n\nz_2->beta_2\n\n\n\n\n\nbeta_0->y_0\n\n\n\n\n\nbeta_1->y_1\n\n\n\n\n\nbeta_2->y_2\n\n\n\n\n\nz_0\n\nz_0\n~\nNormal\n\n\n\nz_0->beta_0\n\n\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 25 + }, + { + "cell_type": "markdown", + "id": "33d16a1bef908f3f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:52.686637Z", + "start_time": "2025-06-26T21:31:52.509222Z" + } + }, + "source": [ + "### Old style vectorization" + ] + }, + { + "cell_type": "markdown", + "id": "c8feda2a4b6bb25d", + "metadata": {}, + "source": [ + "With some work we can rewrite the model to use vectorized operations.\n", + "\n", + "We'll introduce `coords` and :func:`~pymc.Data` to make the model more self-contained." + ] + }, + { + "cell_type": "code", + "id": "96c129be4eadafaf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:40.891985Z", + "start_time": "2025-06-30T15:53:40.820861Z" + } + }, + "source": [ + "coords = {\n", + " \"group\": range(3),\n", + " \"knots\": range(n_knots),\n", + " \"obs\": range(N),\n", + "}\n", + "with pm.Model(coords=coords) as vectorized_splines_model:\n", + " x = pm.Data(\"x\", x_np, dims=\"obs\")\n", + " y_obs = pm.Data(\"y_obs\", y_obs_np, dims=\"obs\")\n", + "\n", + " knots = pm.Data(\"knots\", knots_np, dims=\"knot\")\n", + "\n", + " sigma = pm.HalfCauchy(\"sigma\", beta=1)\n", + " sigma_beta0 = pm.HalfNormal(\"sigma_beta0\", sigma=10)\n", + " beta0 = pm.HalfNormal(\"beta_0\", sigma=sigma_beta0, dims=\"group\")\n", + " z = pm.Normal(\"z\", dims=(\"group\", \"knot\"))\n", + "\n", + " delta_factors = pm.math.softmax(z, axis=-1) # (groups, knot)\n", + " slope_factors = 1 - pm.math.cumsum(delta_factors[:, :-1], axis=-1) # (groups, knot-1)\n", + " spline_slopes = pm.math.concatenate(\n", + " [beta0[:, None], beta0[:, None] * slope_factors], axis=-1\n", + " ) # (groups, knot-1)\n", + " beta = pm.math.concatenate(\n", + " [beta0[:, None], pm.math.diff(spline_slopes, axis=-1)], axis=-1\n", + " ) # (groups, knot)\n", + "\n", + " beta = pm.Deterministic(\"beta\", beta, dims=(\"group\", \"knot\"))\n", + "\n", + " X = pm.math.maximum(0, x[:, None] - knots[None, :]) # (n, knot)\n", + " mu = (X * beta[group_idx_np]).sum(-1) # ((n, knots) * (n, knots)).sum(-1) = (n,)\n", + " y = pm.Normal(\"y\", mu=mu, sigma=sigma, observed=y_obs, dims=\"obs\")" + ], + "outputs": [], + "execution_count": 26 + }, + { + "cell_type": "code", + "id": "bd94fa06a1190a05", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:41.063986Z", + "start_time": "2025-06-30T15:53:40.980450Z" + } + }, + "source": [ + 
"vectorized_splines_model.to_graphviz()" + ], + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\nclusterobs (500)\n\nobs (500)\n\n\nclusterknot (50)\n\nknot (50)\n\n\nclustergroup (3)\n\ngroup (3)\n\n\nclustergroup (3) x knot (50)\n\ngroup (3) x knot (50)\n\n\n\ny_obs\n\ny_obs\n~\nData\n\n\n\nx\n\nx\n~\nData\n\n\n\ny\n\ny\n~\nNormal\n\n\n\nx->y\n\n\n\n\n\ny->y_obs\n\n\n\n\n\nknots\n\nknots\n~\nData\n\n\n\nknots->y\n\n\n\n\n\nsigma\n\nsigma\n~\nHalfcauchy\n\n\n\nsigma->y\n\n\n\n\n\nsigma_beta0\n\nsigma_beta0\n~\nHalfnormal\n\n\n\nbeta_0\n\nbeta_0\n~\nHalfnormal\n\n\n\nsigma_beta0->beta_0\n\n\n\n\n\nbeta\n\nbeta\n~\nDeterministic\n\n\n\nbeta_0->beta\n\n\n\n\n\nbeta->y\n\n\n\n\n\nz\n\nz\n~\nNormal\n\n\n\nz->beta\n\n\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 27 + }, + { + "cell_type": "markdown", + "id": "a98bbe7f2e17783a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:32:15.059906Z", + "start_time": "2025-06-26T21:32:15.024281Z" + } + }, + "source": [ + "The graphviz does not show the whole complexity of the models. The biggest problem lied in the multiple list comprehensions used in the origin model. Every iteration extends the computational graph (basically unrolling the python loop), which becomes unfeasible for PyMC to handle.\n", + "\n", + "The use of 3 likelihood and sets of priors is otherwise fine and can make more sense in some cases." + ] + }, + { + "cell_type": "code", + "id": "90f8f4ee3fd9f2a3", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:41.192185Z", + "start_time": "2025-06-30T15:53:41.162347Z" + } + }, + "source": [ + "from pytensor.graph import FunctionGraph\n", + "\n", + "non_vectorized_splines_model_graph = FunctionGraph(\n", + " outputs=non_vectorized_splines_model.observed_RVs, clone=False\n", + ")\n", + "vectorized_model_nodes = len(\n", + " FunctionGraph(outputs=vectorized_splines_model.basic_RVs, clone=False).apply_nodes\n", + ")\n", + "print(f\"Non-vectorized model has {len(non_vectorized_splines_model_graph.apply_nodes)} nodes\")\n", + "print(f\"Vectorized model has {vectorized_model_nodes} nodes\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Non-vectorized model has 806 nodes\n", + "Vectorized model has 38 nodes\n" + ] + } + ], + "execution_count": 28 + }, + { + "cell_type": "markdown", + "id": "a532b2ec365bfecf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:32:17.432572Z", + "start_time": "2025-06-26T21:32:17.415088Z" + } + }, + "source": [ + "### Vectorization with dims" + ] + }, + { + "cell_type": "markdown", + "id": "15d0e3dcfd9eb94a", + "metadata": {}, + "source": [ + "It is however not trivial to write (or translate into) vectorized code like this. It takes some time to grok the patterns and there are many pain-points. We reckon that the conversion between the first and second model took at least one hour, including debugging and testing that the models were indeed equivalent.\n", + "\n", + "We believe that the :mod:`pymc.dims` module will facilitate writing efficient vectorized code. So let's try and do that." 
 + ] + }, + { + "cell_type": "code", + "id": "33471edc920db393", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:41.689900Z", + "start_time": "2025-06-30T15:53:41.666926Z" + } + }, + "source": [ + "with pm.Model(coords=coords) as dims_splines_model:\n", + "    x = pmd.Data(\"x\", x_np, dims=\"obs\")\n", + "    y_obs = pmd.Data(\"y_obs\", y_obs_np, dims=\"obs\")\n", + "    knots = pmd.Data(\"knots\", knots_np, dims=(\"knot\",))\n", + "    group_idx = pmd.math.as_xtensor(group_idx_np, dims=(\"obs\",))\n", + "\n", + "    sigma = pmd.HalfCauchy(\"sigma\", beta=1)\n", + "    sigma_beta0 = pmd.HalfNormal(\"sigma_beta0\", sigma=10)\n", + "    beta0 = pmd.HalfNormal(\"beta_0\", sigma=sigma_beta0, dims=(\"group\",))\n", + "    z = pmd.Normal(\"z\", dims=(\"group\", \"knot\"))\n", + "\n", + "    delta_factors = pmd.math.softmax(z, dim=\"knot\")\n", + "    slope_factors = 1 - delta_factors.isel(knot=slice(None, -1)).cumsum(\"knot\")\n", + "    spline_slopes = pmd.concat([beta0, beta0 * slope_factors], dim=\"knot\")\n", + "    beta = pm.Deterministic(\"beta\", pmd.concat([beta0, spline_slopes.diff(\"knot\")], dim=\"knot\"))\n", + "\n", + "    X = pmd.math.maximum(0, x - knots)\n", + "    mu = (X * beta.isel(group=group_idx)).sum(\"knot\")\n", + "    y = pmd.Normal(\"y\", mu=mu, sigma=sigma, observed=y_obs)" + ], + "outputs": [], + "execution_count": 29 + }, + { + "cell_type": "code", + "id": "5f55cf90d737dcf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:42.362928Z", + "start_time": "2025-06-30T15:53:42.151483Z" + } + }, + "source": [ + "dims_splines_model.to_graphviz()" + ], + "outputs": [ + { + "data": { + "image/svg+xml": "[SVG output stripped during extraction: graphviz plate diagram of dims_splines_model, with plates obs (500), knot (50), group (3), group (3) x knot (50) and 3 x 50]", + "text/plain": [ + "" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 30 + }, + { + "cell_type": "code", + "id": "3a9dd4669342ff9e", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:42.797760Z", + "start_time": "2025-06-30T15:53:42.794464Z" + } + }, + "source": [ + "# Uncomment if you are willing to wait a long while for the results\n", + "# non_vectorized_splines_model.point_logps()" + ], + "outputs": [], + "execution_count": 31 + }, + { + "cell_type": "code", + "id": "86701f276baaa7c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:44.309696Z", + "start_time": "2025-06-30T15:53:43.273978Z" + } + }, + "source": [ + "vectorized_splines_model.point_logps()" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma': np.float64(-1.14),\n", + " 'sigma_beta0': np.float64(-0.73),\n", + " 'beta_0': np.float64(-2.18),\n", + " 'z': np.float64(-137.84),\n", + " 'y': np.float64(-319962.47)}" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 32 + }, + { + "cell_type": "code", + "id": "f34c353131915cef", + "metadata": { + "ExecuteTime": { + "end_time":
"2025-06-30T15:53:46.279261Z", + "start_time": "2025-06-30T15:53:45.344064Z" + } + }, + "source": [ + "dims_splines_model.point_logps()" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma': np.float64(-1.14),\n", + " 'sigma_beta0': np.float64(-0.73),\n", + " 'beta_0': np.float64(-2.18),\n", + " 'z': np.float64(-137.84),\n", + " 'y': np.float64(-319962.47)}" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 33 + }, + { + "cell_type": "markdown", + "id": "93ec3533ec122baa", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:32:29.344857Z", + "start_time": "2025-06-26T21:32:28.811467Z" + } + }, + "source": [ + "## What about coordinates?" + ] + }, + { + "cell_type": "markdown", + "id": "b20324c9f4f65d4b", + "metadata": {}, + "source": [ + "The new xtensor variable and operations do not propagate information about coordinates, which means you cannot perform coordinate-related operations like you can in xarray.\n", + "This includes things like `sel`, `loc`, `drop`.\n", + "\n", + "While it is perhaps disappointing for someone used to xarray, it is a necessary trade-off to allow PyMC to evaluate the model in a performant way. Just like before, PyMC uses PyTensor under the hood, which at the end of the day compiles functions down into the C (or numba or JAX) backends. None of these backends support dims or coordinates. To be able to keep using these backends, PyTensor rewrites the xtensor operations into equivalent tensor operations, which are pretty much abstract numpy code. This is relative easy to do because it's mostly about aligning dimensions for broadcasting or indexing correctly.\n", + "\n", + "Rewriting coordinate-related operations into numpy-like code is a different matter. Many such operations don't have straightforward equivalency, they are more like querying or joining a database than performing array operations.\n", + "\n", + "PyMC models will keep supporting the `coords` argument as a way to specify dimensions of model variables. But for modelling purposes, only the dimension names and their lengths play a role." + ] + }, + { + "cell_type": "markdown", + "id": "f963a53c229c04b9", + "metadata": {}, + "source": [ + "### One final note of caution on coordinates" + ] + }, + { + "cell_type": "markdown", + "id": "626c004fa7abdccf", + "metadata": {}, + "source": [ + "When you provide coords to a PyMC model, they will be attached to any functions that returns xarray or InferenceData objects.\n", + "\n", + "There is one potential issue with this. Like in xarray it is valid to have multiple arrays with the same dims but different shapes. Some operations, like indexing or concatenating, act on this premise. This is also possible with PyMC models, and in fact we had such a case in the last example when we indexed the spline variable.\n", + "\n", + "After sampling, PyMC will try to reattach the coordinates to any computed variables (i.e., distributions, data or deterministics), but these might not have the right shape, or they might not be correctly aligned. \n", + "\n", + "We illustrate this with next model, where we have two variables with the `a` dim but different shapes, and only one matches the shape of the coordinates specified in the model. When PyMC tries to convert the results of sampling to InferenceData, it will issue a warning and refuse to propagate the original coordinates." 
 + { + "cell_type": "markdown", + "id": "f963a53c229c04b9", + "metadata": {}, + "source": [ + "### One final note of caution on coordinates" + ] + }, + { + "cell_type": "markdown", + "id": "626c004fa7abdccf", + "metadata": {}, + "source": [ + "When you provide coords to a PyMC model, they will be attached to the output of any function that returns xarray or InferenceData objects.\n", + "\n", + "There is one potential issue with this. Like in xarray, it is valid to have multiple arrays with the same dims but different shapes. Some operations, like indexing or concatenating, act on this premise. This is also possible with PyMC models, and in fact we had such a case in the last example, when we indexed the spline variable.\n", + "\n", + "After sampling, PyMC will try to reattach the coordinates to any computed variables (i.e., distributions, data or deterministics), but these might not have the right shape, or they might not be correctly aligned.\n", + "\n", + "We illustrate this with the next model, where we have two variables with the `a` dim but different shapes, and only one matches the shape of the coordinates specified in the model. When PyMC tries to convert the results of sampling to InferenceData, it will issue a warning and refuse to propagate the original coordinates." + ] + }, + { + "cell_type": "code", + "id": "2c6f3a2b70a4d77d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:46.850715Z", + "start_time": "2025-06-30T15:53:46.750208Z" + } + }, + "source": [ + "with pm.Model(coords={\"a\": [-1, 0, 1]}) as m:\n", + "    x = pmd.Normal(\"x\", dims=(\"a\",))\n", + "    y = pmd.Deterministic(\"y\", x.isel(a=slice(1, None)))\n", + "    assert y.dims == (\"a\",)\n", + "\n", + "    idata = pm.sample_prior_predictive()\n", + "idata.prior[\"y\"].coords" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [x]\n", + "/home/ricardo/Documents/pymc/pymc/backends/arviz.py:70: UserWarning: Incompatible coordinate length of 3 found for dimension a of variable y.\n", + "The originate coordinates for this dim will not be included in the returned dataset for any of the variables. Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n", + "To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`\n", + "  warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", + " * a (a) int64 24B 0 1 2" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 34 + }, + { + "cell_type": "markdown", + "id": "a20137348aa4f5bf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.371358319Z", + "start_time": "2025-06-26T11:42:20.084741Z" + } + }, + "source": [ + "A user who wishes to retain coordinates for further analysis will have to manually specify them after sampling, or rename the intermediate dimensions to something else that has compatible coordinates." + ] + },
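 + { + "cell_type": "markdown", + "id": "0a1b2c3d4e5f6071", + "metadata": {}, + "source": [ + "A sketch of the first option, using xarray's `isel` and `assign_coords` on the sampled result (the labels `[0, 1]` are an assumption of what the user intends: they are the coordinates `y` actually corresponds to under `a = [-1, 0, 1]`). The renaming option is shown in the next cell." + ] + }, + { + "cell_type": "code", + "id": "1b2c3d4e5f607182", + "metadata": {}, + "source": [ + "# Post-hoc fix: keep the two valid (non-padded) entries and attach the coordinates y really has\n", + "y_fixed = idata.prior[\"y\"].isel(a=slice(0, 2)).assign_coords(a=[0, 1])\n", + "y_fixed.coords" + ], + "outputs": [], + "execution_count": null + },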
 + { + "cell_type": "code", + "id": "466f152cb249d1d2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:47.572296Z", + "start_time": "2025-06-30T15:53:47.486879Z" + } + }, + "source": [ + "with pm.Model(coords={\"a\": [-3, -2, -1], \"a*\": [-2, -1]}) as m:\n", + "    x = pmd.Normal(\"x\", dims=(\"a\",))\n", + "    y = pmd.Deterministic(\"y\", x.isel(a=slice(1, None)).rename({\"a\": \"a*\"}))\n", + "    assert y.dims == (\"a*\",)\n", + "    # You can rename back to the original name if you need it for further operations\n", + "    y = y.rename({\"a*\": \"a\"})\n", + "\n", + "    idata = pm.sample_prior_predictive(draws=1)\n", + "idata.prior[\"y\"].coords" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [x]\n" + ] + }, + { + "data": { + "text/plain": [ + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 8B 0\n", + " * a* (a*) int64 16B -2 -1" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 35 + }, + { + "cell_type": "markdown", + "id": "9a051144b1259fc8", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.371499143Z", + "start_time": "2025-06-26T11:42:20.159762Z" + } + }, + "source": [ + "Note that when doing advanced indexing, the name of the indexed dimension can be controlled by the name of the indexing xtensor." + ] + }, + { + "cell_type": "code", + "id": "87726153a7d1fd2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:48.186501Z", + "start_time": "2025-06-30T15:53:48.181799Z" + } + }, + "source": "x.isel(a=pmd.math.as_xtensor([0, 1, 2], dims=(\"a*\",))).dims", + "outputs": [ + { + "data": { + "text/plain": [ + "('a*',)" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 36 + }, + { + "cell_type": "markdown", + "id": "38a8684f-4849-4f69-a7c7-7433182358c2", + "metadata": {}, + "source": [ + "Silent bugs can still happen if the shapes are compatible with the wrong coords, as in the example below." + ] + }, + { + "cell_type": "code", + "id": "8443e99c-8c90-40e1-8189-943e374c1387", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:48.819667Z", + "start_time": "2025-06-30T15:53:48.747097Z" + } + }, + "source": [ + "with pm.Model(coords={\"a\": [1, 2, 3]}):\n", + "    x = pmd.Normal(\"x\", dims=(\"a\",))\n", + "    pmd.Deterministic(\"x_reversed\", x[::-1])\n", + "    idata = pm.sample_prior_predictive(draws=1)\n", + "idata.prior[\"x_reversed\"].coords" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [x]\n" + ] + }, + { + "data": { + "text/plain": [ + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 8B 0\n", + " * a (a) int64 24B 1 2 3" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 37 + }, + { + "cell_type": "markdown", + "id": "1fd902d9-33fd-40d6-a66d-fc15f1064b5c", + "metadata": {}, + "source": [ + "Whereas xarray would flip the coordinates:" + ] + }, + { + "cell_type": "code", + "id": "2d282c1a-8a30-476d-8d92-94ea9f9a6971", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:49.345188Z", + "start_time": "2025-06-30T15:53:49.337999Z" + } + }, + "source": [ + "idata.prior[\"x\"].isel(a=slice(None, None, -1)).coords" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 8B 0\n", + " * a (a) int64 24B 3 2 1"
2 1" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 38 + }, + { + "cell_type": "markdown", + "id": "fa8552be-16b8-4dd0-9bdc-38df8d5609b1", + "metadata": {}, + "source": [ + "This is a symptom of PyMC inability to reason about coords symbolically. Is not a new problem with the :mod:`pymc.dims` module, but it is made more likely because the functions from the :mod:`pymc.dims` module require and propagate dimension names everywhere. We are still working on how to work around the problem of incompatible coordinates.\n", + "\n", + "We remind users that :func:`~pymc.Deterministic` are never required in a model, they are just a way to request that some intermediate operations be included in the returned results. If you use them, pay extra attention to whether the model level coordinates are appropriate for the variable in the :func:`~pymc.Deterministic`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc", + "language": "python", + "name": "pymc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}