add in auto_set_hmm_seq for auto gen hmm_seq #39
base: main
1a49f79 to 3f60e31
brandonwillard left a comment:
Looks like this helper function needs a simple test.
Also, I see some upstream changes in this commit (e.g. the docstring changes). Just rebase your local version of this branch onto upstream/main and then (force) push that to your fork's version of this branch (e.g. origin/auto_hmm_seq).
    def auto_set_hmm_seq(N_states, model, states):
        """
        Initiate a HMMStateSeq based on the length of the mixture component.

        This function require pymc3 and HMMStateSeq.
This function name and its docstring need to state that it's creating a transition matrix from rows that are Dirichlet priors. The same goes for the Dirichlet prior on the initial states, pi_0_tt.
Regarding the name, something like create_dirichlet_state_seq might work.
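To make the naming point concrete, here is a minimal pure-Python sketch of what "a transition matrix from rows that are Dirichlet priors" means. It is not the PR's code: `sample_dirichlet` and `sample_transition_matrix` are hypothetical names, the standard-library `random` module stands in for PyMC3, and a concentration of 1.0 is assumed as the non-informative choice.

```python
import random


def sample_dirichlet(alpha, rng):
    # A Dirichlet draw can be formed by normalizing independent
    # Gamma(alpha_i, 1) draws so that they sum to one.
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def sample_transition_matrix(n_states, rng, concentration=1.0):
    # Each row gets its own Dirichlet draw, so each row is a valid
    # probability vector over the next state.
    alpha = [concentration] * n_states
    return [sample_dirichlet(alpha, rng) for _ in range(n_states)]


rng = random.Random(0)
P = sample_transition_matrix(3, rng)      # transition matrix, rows ~ Dirichlet
pi_0 = sample_dirichlet([1.0] * 3, rng)   # initial state distribution ~ Dirichlet
```

A name like create_dirichlet_state_seq tells the caller that both of these priors are being chosen for them, which auto_set_hmm_seq does not.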
    S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
    S_rv.tag.test_value = states

    return locals()
In general, locals() isn't a good thing to return, because it often contains more than is necessary and it doesn't clearly state what the intended returned values/types are. For those reasons, this idiom/approach can unnecessarily restrict garbage collection and confound static analysers—as well as other devs.
Instead, it could simply return a tuple (i.e. return P_rv, pi_0_tt, S_rv) or an explicitly created dict (i.e. return {"P": P_rv, ... }).
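A small plain-Python illustration of the difference, with no PyMC3 involved. The function names and the placeholder values are hypothetical; only the return idioms matter here.

```python
def build_with_locals(n_states):
    # Anti-pattern: every local name, including scratch values and the
    # loop index, is handed back to the caller.
    weight = 1.0 / n_states
    pi_0 = [weight] * n_states
    P = []
    for i in range(n_states):
        P.append([weight] * n_states)
    return locals()


def build_explicit(n_states):
    # Same work, but the return value states exactly what callers may use.
    weight = 1.0 / n_states
    pi_0 = [weight] * n_states
    P = []
    for i in range(n_states):
        P.append([weight] * n_states)
    return {"P": P, "pi_0": pi_0}


leaky = build_with_locals(2)
explicit = build_explicit(2)
print(sorted(leaky))     # also contains n_states, weight and i
print(sorted(explicit))  # only the two intended entries
```

With the explicit version, adding or renaming a temporary variable inside the function cannot silently change what callers receive.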
    P_rv = pm.Deterministic("Gamma", tt.shape_padleft(P_tt))
    pi_0_tt = compute_steady_state(P_rv)

    S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
For consistency, we should probably correct the names of these variables and the names used by the PyMC3 objects they create (i.e. change P_rv to Gamma_rv and S_rv to V_t_rv, or the other way around).
There are other places in the codebase that need these updates, but we can do that separately. In this case, we just don't want to propagate the discord.
    -------
    locals(), a dict of local variables for reference in sampling steps.
    """
    with model:
We can make the model parameter optional (with a default of None) if we use model = pm.modelcontext(model) before this line. pm.modelcontext will get the model from the surrounding with-context, if any, or use the given model when it's non-None.
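The optional-model pattern can be sketched without PyMC3 installed. Everything below is a toy stand-in: `ToyModel`, `toy_modelcontext` and `create_state_seq` are hypothetical names that only mimic the lookup behaviour described above (use the given model when it is non-None, otherwise the innermost active with-context). They are not PyMC3's implementation.

```python
class ToyModel:
    """Hypothetical stand-in for a model class that tracks active contexts."""

    _stack = []  # innermost active model is last

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        ToyModel._stack.append(self)
        return self

    def __exit__(self, *exc_info):
        ToyModel._stack.pop()
        return False


def toy_modelcontext(model=None):
    # Prefer an explicitly given model; otherwise fall back to the
    # innermost model on the context stack.
    if model is not None:
        return model
    if not ToyModel._stack:
        raise TypeError("No model on context stack.")
    return ToyModel._stack[-1]


def create_state_seq(n_states, model=None):
    # Resolve the model first, then enter its context as before.
    model = toy_modelcontext(model)
    with model:
        return model.name, n_states


with ToyModel("outer"):
    from_context = create_state_seq(2)

from_argument = create_state_seq(2, model=ToyModel("explicit"))
```

In the real helper, the equivalent change would be a `model=None` default plus `model = pm.modelcontext(model)` ahead of the `with model:` line.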
Adding a helper function that automatically initializes an HMMStateSeq, based on the number of mixture components, with a non-informative prior.