Skip to content

Commit

Permalink
Use BlockedStep base class for FFBStep and update output in-place
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 31, 2021
1 parent 25efe93 commit dce65ce
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 40 deletions.
77 changes: 44 additions & 33 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pymc3 as pm
import theano.scalar as ts
import theano.tensor as tt
from pymc3.step_methods.arraystep import ArrayStep, Competence
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
Expand All @@ -16,14 +16,16 @@
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

from pymc3_hmm.distributions import DiscreteMarkovChain
from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.utils import compute_trans_freqs

big: float = 1e20
small: float = 1.0 / big


def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
def ffbs_step(
gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray, output: np.ndarray
):
"""Sample a forward-filtered backward-sampled (FFBS) state sequence.
Parameters
Expand Down Expand Up @@ -83,18 +85,15 @@ def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
alpha_nm1 = alpha_n
alphas[..., n] = alpha_n

# The FFBS samples
samples: np.ndarray = np.empty((N,), dtype=np.int8)

# The uniform samples used to sample the categorical states
unif_samples: np.ndarray = np.random.uniform(size=samples.shape)
unif_samples: np.ndarray = np.random.uniform(size=output.shape)

alpha_N: np.ndarray = alphas[..., N - 1]
beta_N: np.ndarray = alpha_N / alpha_N.sum()

state_np1: np.ndarray = np.searchsorted(beta_N.cumsum(), unif_samples[N - 1])

samples[N - 1] = state_np1
output[N - 1] = state_np1

beta_n: np.ndarray = np.empty((M,), dtype=float)

Expand All @@ -104,12 +103,12 @@ def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
beta_n /= np.sum(beta_n)

state_np1 = np.searchsorted(beta_n.cumsum(), unif_samples[n])
samples[n] = state_np1
output[n] = state_np1

return samples
return output


class FFBSStep(ArrayStep):
class FFBSStep(BlockedStep):
r"""Forward-filtering backward-sampling steps.
For a hidden Markov model with state sequence :math:`S_t`, observations
Expand All @@ -126,44 +125,56 @@ class FFBSStep(ArrayStep):

name = "ffbs"

def __init__(self, var, values=None, model=None):
def __init__(self, vars, values=None, model=None):

if len(vars) > 1:
raise ValueError("This sampler only takes one variable.")

(var,) = pm.inputvars(vars)

if not isinstance(var.distribution, DiscreteMarkovChain):
raise TypeError("This sampler only samples `DiscreteMarkovChain`s.")

model = pm.modelcontext(model)

(var,) = pm.inputvars(var)
self.vars = [var]

self.dependent_rvs = [
v
for v in model.basic_RVs
if v is not var and var in graph_inputs([v.logpt])
]

# We compile a function--from a Theano graph--that computes the
# total log-likelihood values for each state in the sequence.
dependents_log_lik = model.fn(
tt.sum([v.logp_elemwiset for v in self.dependent_rvs], axis=0)
)
dep_comps_logp_stacked = []
for i, dependent_rv in enumerate(self.dependent_rvs):
if isinstance(dependent_rv.distribution, SwitchingProcess):
comp_logps = []

self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)
# Get the log-likelihoood sequences for each state in this
# `SwitchingProcess` observations distribution
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))

super().__init__([var], [dependents_log_lik], allvars=True)
comp_logp_stacked = tt.stack(comp_logps)
else:
raise TypeError(
"This sampler only supports `SwitchingProcess` observations"
)

def astep(self, point, log_lik_fn, inputs):
gamma_0 = self.gamma_0_fn(inputs)
Gammas_t = self.Gammas_fn(inputs)
dep_comps_logp_stacked.append(comp_logp_stacked)

M = gamma_0.shape[-1]
N = point.shape[-1]
comp_logp_stacked = tt.sum(dep_comps_logp_stacked, axis=0)

# TODO: Why won't broadcasting work with `log_lik_fn`? Seems like we
# could be missing out on a much more efficient/faster approach to this
# potentially large computation.
# state_seqs = np.broadcast_to(np.arange(M, dtype=int)[..., None], (M, N))
# log_lik_t = log_lik_fn(state_seqs)
log_lik_t = np.stack([log_lik_fn(np.broadcast_to(m, N)) for m in range(M)])
self.log_lik_states = model.fn(comp_logp_stacked)
self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)

return ffbs_astep(gamma_0, Gammas_t, log_lik_t)
def step(self, point):
gamma_0 = self.gamma_0_fn(point)
Gammas_t = self.Gammas_fn(point)
log_lik_state_vals = self.log_lik_states(point)
ffbs_step(gamma_0, Gammas_t, log_lik_state_vals, point[self.vars[0].name])
return point

@staticmethod
def competence(var):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_time_varying_model():

sim_point = pm.sample_prior_predictive(samples=1, model=sim_model)

y_t = sim_point["Y_t"].squeeze()
y_t = sim_point["Y_t"].squeeze().astype(int)

split = int(len(y_t) * 0.7)

Expand Down Expand Up @@ -155,7 +155,7 @@ def test_time_varying_model():
)

# Update the shared variable values
Y.set_value(np.ones(test_X.shape[0]))
Y.set_value(np.ones(test_X.shape[0], dtype=Y.dtype))
X.set_value(test_X)

model.V_t.distribution.shape = (test_X.shape[0],)
Expand Down
14 changes: 9 additions & 5 deletions tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from theano.graph.op import get_test_value

from pymc3_hmm.distributions import DiscreteMarkovChain, PoissonZeroProcess
from pymc3_hmm.step_methods import FFBSStep, TransMatConjugateStep, ffbs_astep
from pymc3_hmm.step_methods import FFBSStep, TransMatConjugateStep, ffbs_step
from pymc3_hmm.utils import compute_steady_state, compute_trans_freqs
from tests.utils import simulate_poiszero_hmm

Expand Down Expand Up @@ -36,13 +36,15 @@ def test_ffbs_astep():
test_log_lik_0 = np.stack(
[np.broadcast_to(0.0, 10000), np.broadcast_to(-np.inf, 10000)]
)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_0)
res = np.empty(test_log_lik_0.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik_0, res)
assert np.all(res == 0)

test_log_lik_1 = np.stack(
[np.broadcast_to(-np.inf, 10000), np.broadcast_to(0.0, 10000)]
)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_1)
res = np.empty(test_log_lik_1.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik_1, res)
assert np.all(res == 1)

# A well-separated mixture with non-degenerate likelihoods
Expand All @@ -59,7 +61,8 @@ def test_ffbs_astep():
# TODO FIXME: This is a statistically unsound/unstable check.
assert np.mean(np.abs(test_log_lik_p.argmax(0) - test_seq)) < 1e-2

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_p)
res = np.empty(test_log_lik_p.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik_p, res)
# TODO FIXME: This is a statistically unsound/unstable check.
assert np.mean(np.abs(res - test_seq)) < 1e-2

Expand All @@ -81,7 +84,8 @@ def test_ffbs_astep():
test_log_lik[::2] = test_log_lik[::2][:, ::-1]
test_log_lik = test_log_lik.T

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik)
res = np.empty(test_log_lik.shape[-1])
ffbs_step(test_gamma_0, test_Gammas, test_log_lik, res)
assert np.array_equal(res, np.r_[1, 0, 0, 1])


Expand Down

0 comments on commit dce65ce

Please sign in to comment.