Skip to content

Commit

Permalink
ChiSquared now returns a Gamma random variable (pymc-devs#7007)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Nov 17, 2023
1 parent cac99d9 commit 305eb39
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 48 deletions.
31 changes: 12 additions & 19 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
BetaRV,
_gamma,
cauchy,
chisquare,
exponential,
gumbel,
halfcauchy,
Expand All @@ -56,7 +55,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import _logcdf_helper, _logprob_helper
from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import icdf

try:
Expand Down Expand Up @@ -2374,16 +2373,21 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(PositiveContinuous):
class ChiSquared:
r"""
:math:`\chi^2` log-likelihood.
This is the distribution from the sum of the squares of :math:`\nu` independent standard normal random variables or a special
case of the gamma distribution with :math:`\alpha = \nu/2` and :math:`\beta = 1/2`.
The pdf of this distribution is
.. math::
f(x \mid \nu) = \frac{x^{(\nu-2)/2}e^{-x/2}}{2^{\nu/2}\Gamma(\nu/2)}
Read more about the :math:`\chi^2` distribution at https://en.wikipedia.org/wiki/Chi-squared_distribution
.. plot::
:context: close-figs
Expand Down Expand Up @@ -2413,24 +2417,13 @@ class ChiSquared(PositiveContinuous):
nu : tensor_like of float
Degrees of freedom (nu > 0).
"""
rv_op = chisquare

@classmethod
def dist(cls, nu, *args, **kwargs):
nu = pt.as_tensor_variable(floatX(nu))
return super().dist([nu], *args, **kwargs)

def moment(rv, size, nu):
moment = nu
if not rv_size_is_none(size):
moment = pt.full(size, moment)
return moment
def __new__(cls, name, nu, **kwargs):
return Gamma(name, alpha=nu / 2, beta=1 / 2, **kwargs)

def logp(value, nu):
return _logprob_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)

def logcdf(value, nu):
return _logcdf_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)
@classmethod
def dist(cls, nu, **kwargs):
return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)


# TODO: Remove this once logp for multiplication is working!
Expand Down
26 changes: 0 additions & 26 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,19 +1132,6 @@ def test_beta_moment(self, alpha, beta, size, expected):
pm.Beta("x", alpha=alpha, beta=beta, size=size)
assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
"nu, size, expected",
[
(1, None, 1),
(1, 5, np.full(5, 1)),
(np.arange(1, 6), None, np.arange(1, 6)),
],
)
def test_chisquared_moment(self, nu, size, expected):
with pm.Model() as model:
pm.ChiSquared("x", nu=nu, size=size)
assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
"lam, size, expected",
[
Expand Down Expand Up @@ -2243,19 +2230,6 @@ class TestInverseGammaMuSigma(BaseTestDistributionRandom):
checks_to_run = ["check_pymc_params_match_rv_op"]


class TestChiSquared(BaseTestDistributionRandom):
pymc_dist = pm.ChiSquared
pymc_dist_params = {"nu": 2.0}
expected_rv_op_params = {"nu": 2.0}
reference_dist_params = {"df": 2.0}
reference_dist = seeded_numpy_distribution_builder("chisquare")
checks_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestLogistic(BaseTestDistributionRandom):
pymc_dist = pm.Logistic
pymc_dist_params = {"mu": 1.0, "s": 2.0}
Expand Down
4 changes: 3 additions & 1 deletion tests/logprob/test_transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import equal_computations

import pymc as pm

from pymc.distributions.transforms import _default_transform, log, logodds
from pymc.logprob import conditional_logp
from pymc.logprob.abstract import MeasurableVariable, _logprob
Expand Down Expand Up @@ -154,7 +156,7 @@ def test_original_values_output_dict():
(),
),
(
pt.random.chisquare,
pm.ChiSquared.dist,
(1.5,),
lambda df: sp.stats.chi2(df),
(),
Expand Down
4 changes: 2 additions & 2 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

from pytensor.graph.basic import equal_computations

from pymc.distributions.continuous import Cauchy
from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
from pymc.logprob.transforms import (
ArccoshTransform,
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_sqr_transform(self):

def test_sqrt_transform(self):
# The sqrt of a chisquare with n df is a chi distribution with n df
x_rv = pt.sqrt(pt.random.chisquare(df=3, size=(4,)))
x_rv = pt.sqrt(ChiSquared.dist(nu=3, size=(4,)))
x_rv.name = "x"

x_vv = x_rv.clone()
Expand Down

0 comments on commit 305eb39

Please sign in to comment.