Codecov Report
❌ Your patch status has failed because the patch coverage (28.57%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## main #635 +/- ##
===========================================
+ Coverage 32.93% 48.54% +15.60%
===========================================
Files 69 73 +4
Lines 7555 7779 +224
===========================================
+ Hits 2488 3776 +1288
+ Misses 5067 4003 -1064
zaxtax left a comment
I think this is pretty neat! After some rebasing and small API changes we should definitely try to get this merged!
The probabilistic model.
guide : AutoGuideModel
The variational guide.
stick_the_landing : bool, optional
I feel like we need a better name than stick_the_landing. Also, the function is supposed to return the logp and logq terms, but the STL estimator is about returning only the path-derivative component of the gradient.
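To make the distinction concrete, here is a minimal sketch of the two gradient estimators for a one-dimensional Gaussian guide and a Gaussian target, using only the standard library. It is not the PR's code: the function name `gradient_samples` and its arguments are hypothetical, and the gradients are written out by hand rather than taken with a stop-gradient. With `z = mu + sigma * eps`, the full reparameterised gradient with respect to `mu` is the path derivative minus the score term, and the STL estimator keeps only the path derivative.

```python
import random
import statistics


def gradient_samples(mu, sigma, m, s, n_draws, seed=0):
    """Per-draw gradients of log p(z) - log q(z) with respect to mu.

    q is the guide N(mu, sigma^2) and p is the target N(m, s^2).
    Returns (full, stl): the full reparameterised gradient and the
    path-derivative-only (STL) gradient for each draw.
    """
    rng = random.Random(seed)
    full, stl = [], []
    for _ in range(n_draws):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps            # reparameterised draw from q
        dlogp_dz = -(z - m) / s**2      # d/dz log p(z)
        dlogq_dz = -eps / sigma         # d/dz log q(z), since z - mu = sigma * eps
        path = dlogp_dz - dlogq_dz      # path derivative (dz/dmu = 1)
        score = eps / sigma             # d/dmu log q(z; mu) with z held fixed
        full.append(path - score)       # full gradient = path minus score term
        stl.append(path)                # STL drops the score term
    return full, stl


# When the guide matches the target, the STL gradient vanishes for every
# draw, while the full gradient still fluctuates around zero.
full, stl = gradient_samples(mu=1.0, sigma=2.0, m=1.0, s=2.0, n_draws=1000)
full_variance = statistics.pvariance(full)
stl_max_abs = max(abs(g) for g in stl)
```

The score term has zero expectation, so both estimators are unbiased; they differ only in variance, which is why the flag changes the gradient rather than the returned logp and logq values.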
Number of MC draws per step for gradient estimation, by default 10.
model : Model
The PyMC model to fit. If None, the model is inferred from context.
state : SVIState, optional
I'd prefer we copy what PyTorch Lightning did here and have the SVI state live in the Trainer object, as something we can pass at initialisation.
return state

def sample_posterior(
    self, draws: int, state: SVIState, model: Model | None = None
Drop the SVIState argument here and move the state into the Trainer object.
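One way this suggestion could look, as a rough sketch rather than the PR's actual API: the `Trainer` class, the fields of `SVIState`, and the `toy_update` step below are all hypothetical stand-ins, and the "model" is a toy quadratic loss instead of a PyMC model. The point is only the shape of the interface: the state is optional at construction, is owned by the trainer afterwards, and no longer appears in the `sample_posterior` signature.

```python
from __future__ import annotations

import random
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class SVIState:
    """Hypothetical optimisation state (not the PR's actual class)."""
    params: dict[str, float]
    step: int = 0
    losses: list[float] = field(default_factory=list)


class Trainer:
    """Hypothetical trainer that owns the SVI state."""

    def __init__(
        self,
        update: Callable[[dict[str, float]], tuple[dict[str, float], float]],
        state: SVIState | None = None,
    ):
        self.update = update
        # Passing a state at initialisation resumes a previous run.
        self.state = state if state is not None else SVIState(params={"mu": 0.0})

    def fit(self, n_steps: int) -> Trainer:
        for _ in range(n_steps):
            params, loss = self.update(self.state.params)
            self.state.params = params
            self.state.step += 1
            self.state.losses.append(loss)
        return self

    def sample_posterior(self, draws: int, seed: int = 0) -> list[float]:
        # No `state` argument: the trainer reads its own state.
        rng = random.Random(seed)
        return [rng.gauss(self.state.params["mu"], 1.0) for _ in range(draws)]


def toy_update(params: dict[str, float]) -> tuple[dict[str, float], float]:
    """Toy gradient-descent step on (mu - 3)^2, standing in for an SVI step."""
    mu = params["mu"]
    loss = (mu - 3.0) ** 2
    return {"mu": mu - 0.1 * 2.0 * (mu - 3.0)}, loss


trainer = Trainer(toy_update).fit(100)
samples = trainer.sample_posterior(draws=5)

# Resuming: hand the existing state to a new trainer at initialisation.
resumed = Trainer(toy_update, state=trainer.state).fit(1)
```

Keeping the state on the trainer means callers never have to thread it through `fit` and `sample_posterior` by hand, at the cost of making the trainer stateful.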
This PR moves the work from pymc-devs/pymc#7799 over here to extras. The key idea is to copy the NumPyro guide model API, but with our own PyMC flair.
I also added:
I updated the notebook to use the proposed training API. It obviously needs a lot of work (given that it's LLM trash), but hopefully it can get some ideas flowing.