From 82f6892f7a671d608e493280887cd662e58058b7 Mon Sep 17 00:00:00 2001 From: Jing Xie Date: Wed, 24 Mar 2021 16:10:03 -0400 Subject: [PATCH] traverse through graph --- pymc3_hmm/distributions.py | 5 +- pymc3_hmm/step_methods.py | 104 +++++++++++++++++++++---------------- 2 files changed, 61 insertions(+), 48 deletions(-) diff --git a/pymc3_hmm/distributions.py b/pymc3_hmm/distributions.py index 1d5d0d8..980017e 100644 --- a/pymc3_hmm/distributions.py +++ b/pymc3_hmm/distributions.py @@ -131,7 +131,6 @@ def __init__(self, comp_dists, states, *args, **kwargs): """ self.states = tt.as_tensor_variable(pm.intX(states)) - self._logp_like = None if len(comp_dists) > 31: warnings.warn( @@ -187,9 +186,7 @@ def logp(self, obs): obs_tt = tt.as_tensor_variable(obs) - shape_var = tuple(obs_tt.shape.tag.test_value) - - logp_val = tt.alloc(-np.inf, *shape_var) + logp_val = tt.alloc(-np.inf, *obs.shape) for i, dist in enumerate(self.comp_dists): i_mask = tt.eq(self.states, i) diff --git a/pymc3_hmm/step_methods.py b/pymc3_hmm/step_methods.py index a117d47..2714885 100644 --- a/pymc3_hmm/step_methods.py +++ b/pymc3_hmm/step_methods.py @@ -13,11 +13,12 @@ from theano.graph.op import get_test_value as test_value from theano.graph.opt import OpRemove, pre_greedy_local_optimizer from theano.graph.optdb import Query +from theano.scan.utils import clone from theano.tensor.elemwise import DimShuffle, Elemwise from theano.tensor.subtensor import AdvancedIncSubtensor1 from theano.tensor.var import TensorConstant -from pymc3_hmm.distributions import DiscreteMarkovChain +from pymc3_hmm.distributions import DiscreteMarkovChain, PoissonZeroProcess from pymc3_hmm.utils import compute_trans_freqs big: float = 1e20 @@ -109,6 +110,44 @@ def ffbs_astep( return samples +def traverse_graph_and_replace(rv, state= None): + visited = [] + + graph = rv.logp_elemwiset + + if type(rv.distribution) is PoissonZeroProcess: + log_k = graph.owner.inputs[0].owner.inputs[0] + shared_log_k = shared(log_k.eval(), name=f'shared_log_k', borrow=True) + clone(graph, {log_k: shared_log_k}) + + st = graph.owner.inputs[0].owner.inputs[1].owner.inputs[0].owner.inputs[1].owner.inputs[1].owner.inputs[1] + st_alloc = st.owner.inputs[0].owner.inputs[1] + shared_st = shared(np.zeros_like(st_alloc.eval()) + state, name=f'shared_st', borrow=True) + + clone(graph, {st_alloc: shared_st}) + + return graph + + queue = [graph] + count = 0 + + while queue: + node = queue.pop(0) + if node not in visited: + visited.append(node) + if node.__str__() =="Alloc.0": + shared_var = shared(node.eval(), name = f'shared_var_{count}', borrow = True) + clone(graph, {node:shared_var}) + count += 1 + if node.owner is not None: + inputs = node.owner.inputs + for input in inputs: + queue.append(input) + + return graph + + + class FFBSStep(ArrayStep): r"""Forward-filtering backward-sampling steps. @@ -142,27 +181,16 @@ def __init__(self, var, values=None, model=None): # total log-likelihood values for each state in the sequence. var_sample = model.test_point[var.name] - self.log_likelihood_values = [] + log_likelihood_values = [] for i, dependent_rv in enumerate(self.dependent_rvs): number_of_state = len(dependent_rv.distribution.comp_dists) - shared_logp = shared(np.zeros((number_of_state, ) +var_sample.shape), - name=f"log_likelihood_values_{i}", borrow=True) - log_p_t = [] - for state_i in range(number_of_state) : - ## theano.graph.basic.clone_replace - ## replace state squence + for state_i in range(number_of_state): + log_p_t.append(traverse_graph_and_replace(dependent_rv, state_i)) - logp_t = dependent_rv.logp_elemwiset.clone() - alloc = logp_t.owner.inputs[0].owner.inputs[0].owner.inputs[0] - if alloc.__str__() =="Alloc.0" : - logp_t.owner.inputs[0].owner.inputs[0] = shared_logp[state_i] - log_p_t.append(logp_t) + log_likelihood_values.append(tt.stack(log_p_t)) - self.log_likelihood_values.append(log_p_t) - - - temp = tt.sum(self.log_likelihood_values, axis=0) + temp = tt.sum(log_likelihood_values, axis = 0) dependents_log_lik = model.fn(temp) @@ -170,6 +198,20 @@ def __init__(self, var, values=None, model=None): self.Gammas_fn = model.fn(var.distribution.Gammas) self.lik_dict = {} + M = number_of_state + N = var_sample.shape[0] + + lik_n: np.ndarray = np.empty((M,), dtype=float) + alpha_n: np.ndarray = np.empty((M,), dtype=float) + beta_n: np.ndarray = np.empty((M,), dtype=float) + samples: np.ndarray = np.empty((N,), dtype=np.int8) + alphas: np.ndarray = np.empty((M, N), dtype=float) + + self.lik_dict["lik_n"] = lik_n + self.lik_dict["alpha_n"] = alpha_n + self.lik_dict["beta_n"] = beta_n + self.lik_dict["samples"] = samples + self.lik_dict["alphas"] = alphas super().__init__([var], [dependents_log_lik], allvars=True) @@ -177,33 +219,7 @@ def astep(self, point, log_lik_fn, inputs): gamma_0 = self.gamma_0_fn(inputs) Gammas_t = self.Gammas_fn(inputs) - M = gamma_0.shape[-1] - N = point.shape[-1] - - state_seqs = np.broadcast_to(np.arange(M, dtype=int)[..., None], (M, N)) - log_lik_t = log_lik_fn(state_seqs) - - - if set(self.lik_dict.keys()) == set([ - "lik_n", - "alpha_n", - "beta_n", - "samples", - "alphas", - ]): - pass - else: - lik_n: np.ndarray = np.empty((M,), dtype=float) - alpha_n: np.ndarray = np.empty((M,), dtype=float) - beta_n: np.ndarray = np.empty((M,), dtype=float) - samples: np.ndarray = np.empty((N,), dtype=np.int8) - alphas: np.ndarray = np.empty((M, N), dtype=float) - - self.lik_dict["lik_n"] = lik_n - self.lik_dict["alpha_n"] = alpha_n - self.lik_dict["beta_n"] = beta_n - self.lik_dict["samples"] = samples - self.lik_dict["alphas"] = alphas + log_lik_t = log_lik_fn(point) return ffbs_astep(gamma_0, Gammas_t, log_lik_t, self.lik_dict)