-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Test initial state probability estimation
- Loading branch information
1 parent
66fa47e
commit d2547f4
Showing
2 changed files
with
48 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import numpy as np | ||
import pymc3 as pm | ||
import theano.tensor as tt | ||
|
||
from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess | ||
|
||
# from pymc3_hmm.step_methods import FFBSStep | ||
from tests.utils import simulate_poiszero_hmm | ||
|
||
|
||
# the way we use this initial probability state vector | ||
def test_gamma_0_estimation(): # annotate this test with pytest | ||
|
||
np.random.seed(2032) | ||
|
||
poiszero_sim, _ = simulate_poiszero_hmm( | ||
30, 150, pi_0=[0.5, 0.5] | ||
) # test [.01, .99], [.99, .01] | ||
y_test = poiszero_sim["Y_t"] | ||
|
||
with pm.Model() as test_model: | ||
|
||
# create a 2x2 matrix with 50% transition probs for both states | ||
p_0_rv = np.array([0.5, 0.5]) | ||
p_1_rv = np.array([0.5, 0.5]) | ||
P_tt = tt.stack([p_0_rv, p_1_rv]) # transition matrix | ||
P_rv = pm.Deterministic("P_tt", tt.shape_padleft(P_tt)) | ||
|
||
# pi_0 = gamma_0 = initial state | ||
pi_0_rv = pm.Dirichlet("pi_0", np.r_[1, 1]) # "flat" prior for initial states | ||
|
||
# take the transition matrix and the flat prior for initial states | ||
# creates a sequence of states and it "chooses" the state and time t based | ||
# on how the trans matrix P_rv is defined and | ||
S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_rv, shape=y_test.shape[0]) | ||
|
||
# takes series of states | ||
SwitchingProcess("Y_t", [pm.Constant.dist(0), pm.Constant.dist(1)], S_rv) | ||
|
||
with test_model: | ||
|
||
prior_predictive = pm.sample_prior_predictive(samples=100) | ||
assert 0.49 < prior_predictive["Y_t"].mean() < 0.51 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters