From d2547f40a94e4894f301e2f834af2f9b8501d2e5 Mon Sep 17 00:00:00 2001 From: Veganveins <16780529+Veganveins@users.noreply.github.com> Date: Tue, 16 Mar 2021 15:56:31 -0400 Subject: [PATCH] Test initial state probability estimation --- tests/test_estimation.py | 43 ++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 7 +++++-- 2 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 tests/test_estimation.py diff --git a/tests/test_estimation.py b/tests/test_estimation.py new file mode 100644 index 0000000..9498062 --- /dev/null +++ b/tests/test_estimation.py @@ -0,0 +1,43 @@ +import numpy as np +import pymc3 as pm +import theano.tensor as tt + +from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess + +# from pymc3_hmm.step_methods import FFBSStep +from tests.utils import simulate_poiszero_hmm + + +# the way we use this initial probability state vector +def test_gamma_0_estimation(): # annotate this test with pytest + + np.random.seed(2032) + + poiszero_sim, _ = simulate_poiszero_hmm( + 30, 150, pi_0=[0.5, 0.5] + ) # test [.01, .99], [.99, .01] + y_test = poiszero_sim["Y_t"] + + with pm.Model() as test_model: + + # create a 2x2 matrix with 50% transition probs for both states + p_0_rv = np.array([0.5, 0.5]) + p_1_rv = np.array([0.5, 0.5]) + P_tt = tt.stack([p_0_rv, p_1_rv]) # transition matrix + P_rv = pm.Deterministic("P_tt", tt.shape_padleft(P_tt)) + + # pi_0 = gamma_0 = initial state + pi_0_rv = pm.Dirichlet("pi_0", np.r_[1, 1]) # "flat" prior for initial states + + # take the transition matrix and the flat prior for initial states + # creates a sequence of states and it "chooses" the state and time t based + # on how the trans matrix P_rv is defined and + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_rv, shape=y_test.shape[0]) + + # takes series of states + SwitchingProcess("Y_t", [pm.Constant.dist(0), pm.Constant.dist(1)], S_rv) + + with test_model: + + prior_predictive = pm.sample_prior_predictive(samples=100) + assert 0.49 < prior_predictive["Y_t"].mean() < 0.51 diff --git a/tests/utils.py b/tests/utils.py index cecba78..861e0c2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ def simulate_poiszero_hmm( - N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1] + N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1], pi_0=None ): with pm.Model() as test_model: @@ -16,7 +16,10 @@ def simulate_poiszero_hmm( P_tt = tt.stack([p_0_rv, p_1_rv]) P_rv = pm.Deterministic("P_tt", tt.shape_padleft(P_tt)) - pi_0_tt = pm.Dirichlet("pi_0", pi_0_a) + if pi_0: + pi_0_tt = pm.Dirichlet("pi_0", pi_0) + else: + pi_0_tt = pm.Dirichlet("pi_0", pi_0_a) S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=N)