Skip to content

Commit

Permalink
store lik_n and alpha in ffb step attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 committed Mar 18, 2021
1 parent de16651 commit a2b711e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
22 changes: 17 additions & 5 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
small: float = 1.0 / big


def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
def ffbs_astep(gamma_0: np.ndarray,
Gammas: np.ndarray,
log_lik: np.ndarray,
lik_dict: dict):
"""Sample a forward-filtered backward-sampled (FFBS) state sequence.
Parameters
Expand Down Expand Up @@ -65,8 +68,8 @@ def ffbs_astep(gamma_0: np.ndarray, Gammas: np.ndarray, log_lik: np.ndarray):
# sequence
Gamma: np.ndarray = np.broadcast_to(Gammas, (N,) + Gammas.shape[-2:])

lik_n: np.ndarray = np.empty((M,), dtype=float)
alpha_n: np.ndarray = np.empty((M,), dtype=float)
lik_n = lik_dict['lik_n']
alpha_n = lik_dict['alpha_n']

# Forward filtering
for n in range(N):
Expand Down Expand Up @@ -146,6 +149,7 @@ def __init__(self, var, values=None, model=None):

self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)
self.lik_dict = {}

super().__init__([var], [dependents_log_lik], allvars=True)

Expand All @@ -160,10 +164,18 @@ def astep(self, point, log_lik_fn, inputs):
# could be missing out on a much more efficient/faster approach to this
# potentially large computation.
# state_seqs = np.broadcast_to(np.arange(M, dtype=int)[..., None], (M, N))
# log_lik_t = log_lik_fn(state_seqs)
log_lik_t = np.stack([log_lik_fn(np.broadcast_to(m, N)) for m in range(M)])

return ffbs_astep(gamma_0, Gammas_t, log_lik_t)
if 'lik_n' in self.lik_dict and 'alpha_n' in self.lik_dict :
pass
else:
lik_n: np.ndarray = np.empty((M,), dtype=float)
alpha_n: np.ndarray = np.empty((M,), dtype=float)
self.lik_dict['lik_n'] = lik_n
self.lik_dict['alpha_n'] = alpha_n


return ffbs_astep(gamma_0, Gammas_t, log_lik_t, self.lik_dict)

@staticmethod
def competence(var):
Expand Down
12 changes: 8 additions & 4 deletions tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,20 @@ def test_ffbs_astep():
test_Gammas = np.array([[[0.9, 0.1], [0.1, 0.9]]])
test_gamma_0 = np.r_[0.5, 0.5]

lik_dict = {}
lik_dict['lik_n'] = np.empty((test_gamma_0.shape[-1],), dtype=float)
lik_dict['alpha_n'] = np.empty((test_gamma_0.shape[-1],), dtype=float)

test_log_lik_0 = np.stack(
[np.broadcast_to(0.0, 10000), np.broadcast_to(-np.inf, 10000)]
)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_0)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_0, lik_dict)
assert np.all(res == 0)

test_log_lik_1 = np.stack(
[np.broadcast_to(-np.inf, 10000), np.broadcast_to(0.0, 10000)]
)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_1)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_1, lik_dict)
assert np.all(res == 1)

# A well-separated mixture with non-degenerate likelihoods
Expand All @@ -59,7 +63,7 @@ def test_ffbs_astep():
# TODO FIXME: This is a statistically unsound/unstable check.
assert np.mean(np.abs(test_log_lik_p.argmax(0) - test_seq)) < 1e-2

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_p)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik_p, lik_dict)
# TODO FIXME: This is a statistically unsound/unstable check.
assert np.mean(np.abs(res - test_seq)) < 1e-2

Expand All @@ -81,7 +85,7 @@ def test_ffbs_astep():
test_log_lik[::2] = test_log_lik[::2][:, ::-1]
test_log_lik = test_log_lik.T

res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik)
res = ffbs_astep(test_gamma_0, test_Gammas, test_log_lik, lik_dict)
assert np.array_equal(res, np.r_[1, 0, 0, 1])


Expand Down

0 comments on commit a2b711e

Please sign in to comment.