
Adding DirichletProcess function #121

Open
larryshamalama wants to merge 1 commit into main
Conversation

larryshamalama
Member

I have made several attempts at this since my first Summer of Code; the most recent one, in #66, was for a Dirichlet Process Mixture.

This PR adds a truncated DirichletProcess function. The API is not fully settled and remains flexible for now.

For mixtures, a truncated DPM can inherit from pm.Mixture, so most of the building blocks are already available. I decided to create this PR to showcase that we can have a non-Mixture DirichletProcess in PyMC:

import pymc as pm
import pymc_experimental as pmx

# K is the truncation level; obs holds the observed data
with pm.Model() as model:
    alpha = pm.Uniform("alpha", 0.0, 10.0)
    base_dist = pm.Normal("base_dist", 0.0, 1.0, shape=(K + 1,))
    sbw, atoms = pmx.dp.DirichletProcess("dp", alpha, base_dist, K, observed=obs)

    trace = pm.sample(target_accept=0.95)

The DirichletProcess function returns two objects: a stick-breaking weights variable created from pm.StickBreakingWeights.dist(), and atoms, each of which is either one of the observations in obs or a newly drawn atom from the base distribution base_dist. See the notebooks added in this PR for a demonstration of an example from Bayesian Nonparametric Data Analysis.
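
For readers less familiar with the construction, here is a minimal NumPy sketch of the truncated stick-breaking process that pm.StickBreakingWeights encapsulates (the function name and the fixed alpha are illustrative, not part of this PR):

import numpy as np

def stick_breaking_weights(alpha, K, rng):
    # Break K sticks: each proportion is Beta(1, alpha) distributed.
    betas = rng.beta(1.0, alpha, size=K)
    # Length of stick remaining before each break: cumprod of (1 - beta_j).
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas)])
    # First K weights plus the leftover mass on the (K + 1)-th atom,
    # matching the shape=(K + 1,) of base_dist above.
    return np.concatenate([betas, [1.0]]) * remaining

weights = stick_breaking_weights(alpha=2.0, K=19, rng=np.random.default_rng(0))
assert np.isclose(weights.sum(), 1.0)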

Posterior predictive sampling will need some thought. I will add a comment in the code on where this needs editing. For anyone reading this PR, comments are appreciated.

Special thanks to @ricardoV94 for bouncing ideas back and forth recently, and @fonnesbeck, @AustinRochford for the initial supervision for this project.

Tests are to be added, of course. I am also thinking of revisiting #66, which was created a while ago.

@review-notebook-app

Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter notebooks.

atoms = pm.Deterministic(
    atoms_name,
    var=pt.stack([pt.constant(observed)[atom_selection], base_dist], axis=-1)[
larryshamalama (Member Author)

@ricardoV94 Following our conversation a few weeks (or months?) ago, I was able to make this work. Thanks for the ideas.

However, I believe that posterior predictive sampling would require defining a custom distribution class. I'm not so sure at this point; I believe this would need some creativity and possibly revisiting the sketch that you thought of a while back.
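
For readers following along, here is a hypothetical NumPy illustration of the idea in the snippet above: each atom in the Deterministic is either one of the observed values or a fresh draw from the base distribution. The selection variables below are placeholders with arbitrary probabilities, not the PR's actual indexing logic:

import numpy as np

rng = np.random.default_rng(0)
observed = np.array([-1.2, 0.4, 2.1])
K = 4

atom_selection = rng.integers(0, len(observed), size=K + 1)  # which datum to reuse
base_draws = rng.normal(0.0, 1.0, size=K + 1)                # fresh draws from G0
pick_new = rng.integers(0, 2, size=K + 1)                    # 0 = reuse datum, 1 = new draw

# Stack both candidate sources along the last axis and select per atom,
# mirroring the pt.stack([...], axis=-1) followed by indexing in the diff.
candidates = np.stack([observed[atom_selection], base_draws], axis=-1)
atoms = candidates[np.arange(K + 1), pick_new]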

@larryshamalama larryshamalama marked this pull request as ready for review March 23, 2023 13:18
    Scale concentration parameter (alpha > 0) specifying the size of "sticks", or generated
    weights, from the stick-breaking process. Ideally, alpha should have a prior and not be
    a fixed constant.
    base_dist: single batched distribution
ricardoV94 (Member) commented Mar 29, 2023

Would it make sense to use the same API as in other distribution factories, where the user passes a .dist variable and we resize it ourselves (and, in this case, register it in the model as well)?

https://github.com/pymc-devs/pymc/blob/f3ce16f2606f523137c27466069f1ab737626f21/pymc/distributions/censored.py#L55-L59
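
For comparison, a sketch of what that suggested API could look like here, by analogy with pm.Censored; the pmx.dp.DirichletProcess signature is the one proposed in this PR and may change:

import numpy as np
import pymc as pm
import pymc_experimental as pmx

obs = np.random.default_rng(0).normal(size=50)  # placeholder data
K = 19

with pm.Model() as model:
    alpha = pm.Gamma("alpha", 2.0, 1.0)
    # An unnamed, unregistered variable, as with pm.Censored's dist argument;
    # the factory would resize it to (K + 1,) and register it in the model.
    base_dist = pm.Normal.dist(0.0, 1.0)
    sbw, atoms = pmx.dp.DirichletProcess("dp", alpha, base_dist, K, observed=obs)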

larryshamalama (Member Author)

Yes, good idea

larryshamalama (Member Author) commented Jun 28, 2023

Revisiting this PR after a while. I have concluded that bringing an API for Dirichlet Processes as intuitive as PyMC's conventional API would be quite difficult unless it is "hackathon"ed over several days. The main reason is that conditioning on data, i.e. using the observed=... keyword in PyMC, does not blend well with the fact that the base distribution G0 remains "unobserved" throughout.

As recommended by @ricardoV94, this could be solved by concatenating observed values with G0 in a SymbolicRandomVariable and having a logp-ignorant custom sampler that behaves differently depending on whether sample_prior_predictive or sample_posterior_predictive is called. Given the time it took to get here, I think I will pause this work until I can devote a couple of days (maybe weeks) to focus on this avenue.
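
Concretely, the behaviour such a sampler would need to emulate for posterior predictive draws is the Pólya urn characterization of the DP posterior, marginalizing over G. A minimal sketch outside PyMC, with a Normal G0 assumed purely for illustration:

import numpy as np

def posterior_predictive_draw(obs, alpha, rng):
    n = len(obs)
    # With probability alpha / (alpha + n), draw a fresh atom from G0;
    # otherwise resample one of the observed values uniformly.
    if rng.uniform() < alpha / (alpha + n):
        return rng.normal(0.0, 1.0)  # G0 = Normal(0, 1), for illustration
    return rng.choice(obs)

rng = np.random.default_rng(42)
obs = np.array([-0.5, 0.1, 1.3])
draws = [posterior_predictive_draw(obs, alpha=2.0, rng=rng) for _ in range(1000)]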

I will commit my recent changes, but my impression is that this PR could be added under docs rather than as formal functionality. The docs would introduce how a posterior Dirichlet Process can be integrated, i.e. "hacked", into our PPL, even though several functionalities will fall short. What do you think @ricardoV94?

ricardoV94 (Member)
Sounds reasonable
