Skip to content

Commit

Permalink
Fixed bug to allow y being shared variable
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 authored and brandonwillard committed Mar 2, 2022
1 parent 087c03a commit d16c55d
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 22 deletions.
61 changes: 41 additions & 20 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,11 @@ def hs_regression_model_Normal(dist, rv, model):
mu = dist.mu
y_X_fn = None
if hasattr(rv, "observations"):
obs = at.as_tensor_variable(rv.observations)
obs_fn = model.fn(obs)

def y_X_fn(points, X):
return rv.observations, X
return obs_fn(points), X

return y_X_fn, mu

Expand All @@ -567,13 +569,18 @@ def hs_regression_model_NegativeBinomial(dist, rv, model):
if hasattr(rv, "observations"):
from polyagamma import random_polyagamma

# pm.model.fn.signiturw
obs = at.as_tensor_variable(rv.observations)
h_z_alpha_fn = model.fn(
[alpha + rv.observations, eta.squeeze() - at.log(alpha), alpha]
[
alpha + obs,
eta.squeeze() - at.log(alpha),
alpha,
obs,
]
)

def y_X_fn(points, X):
h, z, alpha = h_z_alpha_fn(points)
h, z, alpha, obs = h_z_alpha_fn(points)

omega = random_polyagamma(h, z)

Expand All @@ -586,7 +593,7 @@ def y_X_fn(points, X):
else:
Phi = (X.T * np.sqrt(V_diag_inv)).T

y_aug = np.log(alpha) + (rv.observations - alpha) / (2.0 * omega)
y_aug = np.log(alpha) + (obs - alpha) / (2.0 * omega)
y_aug = (y_aug / sigma).astype(config.floatX)
return y_aug, Phi

Expand All @@ -595,6 +602,28 @@ def y_X_fn(points, X):
return None, eta


def find_dot(node, beta, model, y_fn):
if not node.owner:
return
# dense dot
if isinstance(node.owner.op, Dot):
if beta in node.owner.inputs:
X_fn = model.fn(node.owner.inputs[1].T)
return node, X_fn, y_fn
# sprase dot
if isinstance(node.owner.op, StructuredDot):
if beta in node.owner.inputs[1].owner.inputs:
X_fn = model.fn(node.owner.inputs[0])
return node, X_fn, y_fn
else:
# if exp transformation
if isinstance(node.owner.op, at.elemwise.Elemwise):
res = find_dot(node.owner.inputs[0], beta, model, y_fn)
if res:
node, X_fn, _ = res
return node, X_fn, y_fn


class HSStep(BlockedStep):
name = "hsgibbs"

Expand Down Expand Up @@ -625,15 +654,13 @@ def __init__(self, vars, values=None, model=None):
continue
elif isinstance(var, pm.model.DeterministicWrapper):
eta = var.owner.inputs[0]

dense_dot = eta.owner and isinstance(eta.owner.op, Dot)
sparse_dot = eta.owner and isinstance(eta.owner.op, StructuredDot)

dense_inputs = dense_dot and beta in eta.owner.inputs
sparse_inputs = sparse_dot and beta in eta.owner.inputs[1].owner.inputs

if not (dense_inputs or sparse_inputs):
continue
if eta.owner:
eta_X_fn = find_dot(eta, beta, model, y_X_fn)
if not eta_X_fn:
continue
eta, X_fn, y_X_fn = eta_X_fn
else:
continue # pragma: no cover

if not y_X_fn:
# We don't have the observation distribution, so we need to
Expand All @@ -656,11 +683,6 @@ def __init__(self, vars, values=None, model=None):
if var != obs_mu:
continue

if dense_inputs:
X_fn = model.fn(eta.owner.inputs[1].T)
else:
X_fn = model.fn(eta.owner.inputs[0])

if not (X_fn and y_X_fn):
raise NotImplementedError(
f"Cannot find a design matrix or dependent variable associated with {beta}" # noqa: E501
Expand All @@ -672,7 +694,6 @@ def __init__(self, vars, values=None, model=None):

# if observation dist is normal then y_aug_fn = y_fn when it is NB
# then, hs_regression_model, dispatch i.distribution...

self.vi = np.full(M, 1)
self.lambda2 = np.full(M, 1)
self.beta = np.full(M, 1)
Expand Down
50 changes: 48 additions & 2 deletions tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

try:
import aesara.tensor as at
from aesara import shared
from aesara.graph.op import get_test_value
from aesara.sparse import structured_dot as sp_dot
except ImportError:
import theano.tensor as at
from theano.graph.op import get_test_value
from theano.sparse import structured_dot as sp_dot
from theano import shared

import pymc3 as pm
import pytest
import scipy as sp
from pymc3.exceptions import SamplingError

from pymc3_hmm.distributions import DiscreteMarkovChain, HorseShoe, PoissonZeroProcess
from pymc3_hmm.step_methods import (
Expand Down Expand Up @@ -449,7 +452,7 @@ def test_HSStep_Normal_Deterministic():

beta_samples = trace.posterior["beta"][0].values
assert beta_samples.shape == (50, M)
np.testing.assert_allclose(beta_samples.mean(0), beta_true, atol=0.3)
np.testing.assert_allclose(beta_samples.mean(0), beta_true, atol=0.5)


def test_HSStep_unsupported():
Expand Down Expand Up @@ -557,8 +560,15 @@ def test_HSStep_NegativeBinomial():
eta = pm.NegativeBinomial("eta", mu=beta.dot(X.T), alpha=1, shape=N)
pm.Normal("y", mu=at.exp(eta), sigma=1, observed=y_nb)

with pytest.raises(NotImplementedError):
with pytest.raises(SamplingError):
HSStep([beta])
pm.sample(
draws=N_draws,
step=hsstep,
chains=1,
return_inferencedata=True,
compute_convergence_checks=False,
)


def test_HSStep_NegativeBinomial_sparse():
Expand Down Expand Up @@ -589,3 +599,39 @@ def test_HSStep_NegativeBinomial_sparse():
beta_samples = trace.posterior["beta"][0].values
assert beta_samples.shape == (N_draws, M)
np.testing.assert_allclose(beta_samples.mean(0), beta_true, atol=0.5)


def test_HSStep_NegativeBinomial_sparse_shared_y():
np.random.seed(2032)
M = 5
N = 50
X = np.random.normal(size=N * M).reshape((N, M))
beta_true = np.array([1, 1, 2, 2, 0])
y_nb = pm.NegativeBinomial.dist(np.exp(X.dot(beta_true)), 1).random()

X = sp.sparse.csr_matrix(X)

X_tt = shared(X, name="X", borrow=True)
y_tt = shared(y_nb, name="y_t", borrow=True)

N_draws = 100
with pm.Model():
beta = HorseShoe("beta", tau=1, shape=M)
pm.NegativeBinomial(
"y",
mu=at.exp(sp_dot(X_tt, at.shape_padright(beta))),
alpha=1,
observed=y_tt,
)
hsstep = HSStep([beta])
trace = pm.sample(
draws=N_draws,
step=hsstep,
chains=1,
return_inferencedata=True,
compute_convergence_checks=False,
)

beta_samples = trace.posterior["beta"][0].values
assert beta_samples.shape == (N_draws, M)
np.testing.assert_allclose(beta_samples.mean(0), beta_true, atol=0.5)

0 comments on commit d16c55d

Please sign in to comment.