From 66fa47e19e1446f82eaebc7e363ded019820f651 Mon Sep 17 00:00:00 2001 From: Veganveins <16780529+Veganveins@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:20:20 -0400 Subject: [PATCH 1/2] Test steady state approximation --- tests/test_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index db81ef1..1d789a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ import theano.tensor as tt from pymc3_hmm.utils import ( + compute_steady_state, compute_trans_freqs, logdotexp, multilogit_inv, @@ -120,3 +121,17 @@ def test_multilogit_inv(test_input, test_output): res = multilogit_inv(tt.as_tensor_variable(test_input)) res = res.eval() assert np.array_equal(res.round(2), test_output) + + +test_cases = [ + (np.ones((2, 2)) * 0.5, np.array([0.5, 0.5])), + (np.eye(4), np.array([1, 0, 0, 0])), +] + + +@pytest.mark.parametrize("transition_matrix, steady_state", test_cases) +def test_compute_steady_state(transition_matrix, steady_state): + + P = tt.as_tensor_variable(transition_matrix) + ss_probs = compute_steady_state(P) + np.testing.assert_almost_equal(ss_probs.eval(), steady_state, 1) From 42d64ef770b8c033ab1440a33a947beb5a518553 Mon Sep 17 00:00:00 2001 From: Veganveins <16780529+Veganveins@users.noreply.github.com> Date: Tue, 16 Mar 2021 15:56:31 -0400 Subject: [PATCH 2/2] Test initial state probability estimation --- tests/test_estimation.py | 62 ++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 7 +++-- 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/test_estimation.py diff --git a/tests/test_estimation.py b/tests/test_estimation.py new file mode 100644 index 0000000..a790fba --- /dev/null +++ b/tests/test_estimation.py @@ -0,0 +1,62 @@ +import numpy as np +import pymc3 as pm +import theano.tensor as tt + +from pymc3_hmm.distributions import DiscreteMarkovChain, PoissonZeroProcess +from pymc3_hmm.step_methods import FFBSStep + + +def simulate_poiszero_hmm( + N, + observed, + mu=10.0, + pi_0_a=np.r_[1, 1], + p_0_a=np.r_[5, 1], + p_1_a=np.r_[1, 1], + pi_0=None, +): + p_0_rv = pm.Dirichlet("p_0", p_0_a) + p_1_rv = pm.Dirichlet("p_1", p_1_a) + P_tt = tt.stack([p_0_rv, p_1_rv]) + P_rv = pm.Deterministic("P_tt", tt.shape_padleft(P_tt)) + if pi_0 is not None: + pi_0_tt = tt.as_tensor(pi_0) + else: + pi_0_tt = pm.Dirichlet("pi_0", pi_0_a) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=N) + S_rv.tag.test_value = (observed > 0) * 1 + return PoissonZeroProcess("Y_t", mu, S_rv, observed=observed) + + +def test_gamma_0_estimation(): + + np.random.seed(2032) + true_initial_states = np.array([0.02, 0.98]) + + with pm.Model(theano_config={"compute_test_value": "ignore"}) as sim_model: + _ = simulate_poiszero_hmm(30, np.zeros(30), 150, pi_0=true_initial_states) + + sim_point = pm.sample_prior_predictive(samples=1, model=sim_model) + sim_point["Y_t"] = sim_point["Y_t"].squeeze() + y_test = sim_point["Y_t"] + + with pm.Model() as test_model: + _ = simulate_poiszero_hmm( + 30, + y_test, + 150, + ) + states_step = FFBSStep([test_model["S_t"]]) + + posterior_trace = pm.sample( + step=[states_step], + draws=5, + return_inferencedata=True, + chains=1, + progressbar=True, + ) + + estimated_initial_state_probs = posterior_trace.posterior.pi_0.values[0].mean(0) + np.testing.assert_almost_equal( + estimated_initial_state_probs, true_initial_states, 1 + ) diff --git a/tests/utils.py b/tests/utils.py index cecba78..861e0c2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ def simulate_poiszero_hmm( - N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1] + N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1], pi_0=None ): with pm.Model() as test_model: @@ -16,7 +16,10 @@ def simulate_poiszero_hmm( P_tt = tt.stack([p_0_rv, p_1_rv]) P_rv = pm.Deterministic("P_tt", tt.shape_padleft(P_tt)) - pi_0_tt = pm.Dirichlet("pi_0", pi_0_a) + if pi_0: + pi_0_tt = pm.Dirichlet("pi_0", pi_0) + else: + pi_0_tt = pm.Dirichlet("pi_0", pi_0_a) S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=N)