Skip to content

Commit

Permalink
Rename pm.Constant to pm.DiracDelta (pymc-devs#5903)
Browse files Browse the repository at this point in the history
  • Loading branch information
cluhmann authored Jun 18, 2022
1 parent 66215ac commit f9b749b
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/distributions/discrete.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Discrete
DiscreteWeibull
Poisson
NegativeBinomial
Constant
DiracDelta
ZeroInflatedPoisson
ZeroInflatedBinomial
ZeroInflatedNegativeBinomial
Expand Down
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
Binomial,
Categorical,
Constant,
DiracDelta,
DiscreteUniform,
DiscreteWeibull,
Geometric,
Expand Down Expand Up @@ -140,6 +141,7 @@
"Bernoulli",
"Poisson",
"NegativeBinomial",
"DiracDelta",
"Constant",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial",
Expand Down
42 changes: 30 additions & 12 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"DiscreteWeibull",
"Poisson",
"NegativeBinomial",
"DiracDelta",
"Constant",
"ZeroInflatedPoisson",
"ZeroInflatedBinomial",
Expand Down Expand Up @@ -1337,11 +1338,11 @@ def logp(value, p):
)


class ConstantRV(RandomVariable):
name = "constant"
class DiracDeltaRV(RandomVariable):
name = "diracdelta"
ndim_supp = 0
ndims_params = [0]
_print_name = ("Constant", "\\operatorname{Constant}")
_print_name = ("DiracDelta", "\\operatorname{DiracDelta}")

def make_node(self, rng, size, dtype, c):
c = at.as_tensor_variable(c)
Expand All @@ -1354,22 +1355,22 @@ def rng_fn(cls, rng, c, size=None):
return np.full(size, c)


constant = ConstantRV()
diracdelta = DiracDeltaRV()


class Constant(Discrete):
class DiracDelta(Discrete):
r"""
Constant log-likelihood.
DiracDelta log-likelihood.
Parameters
----------
c: float or int
Constant parameter. The dtype of `c` determines the dtype of the distribution.
This can affect which sampler is assigned to Constant variables, or variables
that use Constant, such as Mixtures.
Dirac Delta parameter. The dtype of `c` determines the dtype of the distribution.
This can affect which sampler is assigned to DiracDelta variables, or variables
that use DiracDelta, such as Mixtures.
"""

rv_op = constant
rv_op = diracdelta

@classmethod
def dist(cls, c, *args, **kwargs):
Expand All @@ -1385,7 +1386,7 @@ def moment(rv, size, c):

def logp(value, c):
r"""
Calculate log-probability of Constant distribution at specified value.
Calculate log-probability of DiracDelta distribution at specified value.
Parameters
----------
Expand All @@ -1411,6 +1412,23 @@ def logcdf(value, c):
)


class Constant:
def __new__(cls, *args, **kwargs):
warnings.warn(
"pm.Constant has been deprecated. Use pm.DiracDelta instead.",
FutureWarning,
)
return DiracDelta(*args, **kwargs)

@classmethod
def dist(cls, *args, **kwargs):
warnings.warn(
"pm.Constant has been deprecated. Use pm.DiracDelta instead.",
FutureWarning,
)
return DiracDelta.dist(*args, **kwargs)


def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
"""Helper function to create a zero-inflated mixture
Expand All @@ -1419,7 +1437,7 @@ def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
nonzero_p = at.as_tensor_variable(floatX(nonzero_p))
weights = at.stack([1 - nonzero_p, nonzero_p], axis=-1)
comp_dists = [
Constant.dist(0),
DiracDelta.dist(0),
nonzero_dist,
]
if name is not None:
Expand Down
16 changes: 12 additions & 4 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def polyagamma_cdf(*args, **kwargs):
Cauchy,
ChiSquared,
Constant,
DiracDelta,
Dirichlet,
DirichletMultinomial,
DiscreteUniform,
Expand Down Expand Up @@ -1729,9 +1730,9 @@ def test_poisson(self):
{"mu": Rplus},
)

def test_constantdist(self):
check_logp(Constant, I, {"c": I}, lambda value, c: np.log(c == value))
check_logcdf(Constant, I, {"c": I}, lambda value, c: np.log(value >= c))
def test_diracdeltadist(self):
check_logp(DiracDelta, I, {"c": I}, lambda value, c: np.log(c == value))
check_logcdf(DiracDelta, I, {"c": I}, lambda value, c: np.log(value >= c))

def test_zeroinflatedpoisson(self):
def logp_fn(value, psi, mu):
Expand Down Expand Up @@ -3065,7 +3066,7 @@ def test_issue_4499(self):
assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)

with pm.Model(check_bounds=False) as m:
x = pm.Constant("x", 1, size=10)
x = pm.DiracDelta("x", 1, size=10)
assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10)


Expand Down Expand Up @@ -3328,3 +3329,10 @@ def test_zero_inflated_dists_dtype_and_broadcast(dist, non_psi_args):
x = dist([0.5, 0.5, 0.5], *non_psi_args)
assert x.dtype in discrete_types
assert x.eval().shape == (3,)


def test_constantdist_deprecated():
with pytest.warns(FutureWarning, match="DiracDelta"):
with Model() as m:
x = Constant("x", c=1)
assert isinstance(x.owner.op, pm.distributions.discrete.DiracDeltaRV)
6 changes: 3 additions & 3 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
Categorical,
Cauchy,
ChiSquared,
Constant,
DensityDist,
DiracDelta,
Dirichlet,
DirichletMultinomial,
DiscreteUniform,
Expand Down Expand Up @@ -617,9 +617,9 @@ def test_negative_binomial_moment(n, p, size, expected):
(np.arange(1, 6), None, np.arange(1, 6)),
],
)
def test_constant_moment(c, size, expected):
def test_diracdelta_moment(c, size, expected):
with Model() as model:
Constant("x", c=c, size=size)
DiracDelta("x", c=c, size=size)
assert_moment_is_expected(model, expected)


Expand Down
22 changes: 11 additions & 11 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,17 +1501,17 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng):
]


class TestConstant(BaseTestDistributionRandom):
def constant_rng_fn(self, size, c):
class TestDiracDelta(BaseTestDistributionRandom):
def diracdelta_rng_fn(self, size, c):
if size is None:
return c
return np.full(size, c)

pymc_dist = pm.Constant
pymc_dist = pm.DiracDelta
pymc_dist_params = {"c": 3}
expected_rv_op_params = {"c": 3}
reference_dist_params = {"c": 3}
reference_dist = lambda self: self.constant_rng_fn
reference_dist = lambda self: self.diracdelta_rng_fn
checks_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
Expand All @@ -1524,10 +1524,10 @@ def constant_rng_fn(self, size, c):
)
def test_dtype(self, floatX):
with aesara.config.change_flags(floatX=floatX):
assert pm.Constant.dist(2**4).dtype == "int8"
assert pm.Constant.dist(2**16).dtype == "int32"
assert pm.Constant.dist(2**32).dtype == "int64"
assert pm.Constant.dist(2.0).dtype == floatX
assert pm.DiracDelta.dist(2**4).dtype == "int8"
assert pm.DiracDelta.dist(2**16).dtype == "int32"
assert pm.DiracDelta.dist(2**32).dtype == "int64"
assert pm.DiracDelta.dist(2.0).dtype == floatX


class TestOrderedLogistic(BaseTestDistributionRandom):
Expand Down Expand Up @@ -1860,8 +1860,8 @@ def ref_rand(size, n, eta):

class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
size = None

sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
Expand Down Expand Up @@ -1891,7 +1891,7 @@ def check_rv_size(self):
def check_draws_match_expected(self):
# TODO: Find better comparison:
rng = aesara.shared(self.get_random_state(reset=True))
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.Constant.dist([0.5, 2.0]), rng=rng)
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.DiracDelta.dist([0.5, 2.0]), rng=rng)
assert np.all(np.abs(x.eval() - np.array([0.5, 0, 2.0])) < 0.01)


Expand Down
8 changes: 4 additions & 4 deletions pymc/tests/test_distributions_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from pymc.aesaraf import floatX
from pymc.distributions.continuous import Flat, HalfNormal, Normal
from pymc.distributions.discrete import Constant
from pymc.distributions.discrete import DiracDelta
from pymc.distributions.logprob import logp
from pymc.distributions.multivariate import Dirichlet
from pymc.distributions.timeseries import (
Expand Down Expand Up @@ -100,11 +100,11 @@ class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
size = None

pymc_dist = pm.GaussianRandomWalk
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init_dist": pm.Constant.dist(0), "steps": 4}
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init_dist": pm.DiracDelta.dist(0), "steps": 4}
expected_rv_op_params = {
"mu": 1.0,
"sigma": 2,
"init_dist": pm.Constant.dist(0),
"init_dist": pm.DiracDelta.dist(0),
"steps": 4,
}

Expand Down Expand Up @@ -455,7 +455,7 @@ def test_multivariate_init_dist(self):
)
def test_moment(self, size, expected):
with Model() as model:
init_dist = Constant.dist([[1.0, 2.0], [3.0, 4.0]])
init_dist = DiracDelta.dist([[1.0, 2.0], [3.0, 4.0]])
AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size)
assert_moment_is_expected(model, expected, check_finite_logp=False)

Expand Down
6 changes: 3 additions & 3 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ def test_issue_4490(self):
def test_aesara_function_kwargs(self):
sharedvar = aesara.shared(0)
with pm.Model() as m:
x = pm.Constant("x", 0)
x = pm.DiracDelta("x", 0)
y = pm.Deterministic("y", x + sharedvar)

prior = pm.sample_prior_predictive(
Expand Down Expand Up @@ -1361,7 +1361,7 @@ def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
def test_aesara_function_kwargs(self):
sharedvar = aesara.shared(0)
with pm.Model() as m:
x = pm.Constant("x", 0.0)
x = pm.DiracDelta("x", 0.0)
y = pm.Deterministic("y", x + sharedvar)

pp = pm.sample_posterior_predictive(
Expand Down Expand Up @@ -1434,7 +1434,7 @@ def test_draw_different_samples(self):

def test_draw_aesara_function_kwargs(self):
sharedvar = aesara.shared(0)
x = pm.Constant.dist(0.0)
x = pm.DiracDelta.dist(0.0)
y = x + sharedvar
draws = pm.draw(
y,
Expand Down

0 comments on commit f9b749b

Please sign in to comment.