Skip to content

New ADVI API #635

Open
jessegrabowski wants to merge 8 commits intopymc-devs:mainfrom
jessegrabowski:advi-refactor
Open

New ADVI API #635
jessegrabowski wants to merge 8 commits intopymc-devs:mainfrom
jessegrabowski:advi-refactor

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Feb 2, 2026

This PR moves the work from pymc-devs/pymc#7799 over here to extras. The key idea is to copy the numpyro guide model API, but with our own PyMC flair.

I also added:

  • stick the landing estimator
  • forward sampling helper
  • LLM sketch for a training API patterned after pytorch-lightning

I updated the notebook to use the proposed training API. It obviously needs a lot of work (being that it's llm trash) but hopefully it can get some ideas flowing.

@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 28.57143% with 160 lines in your changes missing coverage. Please review.
✅ Project coverage is 48.54%. Comparing base (12a18c2) to head (bfdb98b).

Files with missing lines Patch % Lines
pymc_extras/inference/advi/training.py 0.00% 140 Missing ⚠️
pymc_extras/inference/advi/objective.py 0.00% 13 Missing ⚠️
pymc_extras/inference/advi/pytensorf.py 75.00% 7 Missing ⚠️

❌ Your patch status has failed because the patch coverage (28.57%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main     #635       +/-   ##
===========================================
+ Coverage   32.93%   48.54%   +15.60%     
===========================================
  Files          69       73        +4     
  Lines        7555     7779      +224     
===========================================
+ Hits         2488     3776     +1288     
+ Misses       5067     4003     -1064     
Files with missing lines Coverage Δ
pymc_extras/inference/advi/autoguide.py 100.00% <100.00%> (ø)
pymc_extras/inference/advi/pytensorf.py 75.00% <75.00%> (ø)
pymc_extras/inference/advi/objective.py 0.00% <0.00%> (ø)
pymc_extras/inference/advi/training.py 0.00% <0.00%> (ø)

... and 14 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

@zaxtax zaxtax left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is pretty neat! After some rebasing and small API changes we should definitely try to get this merged!

The probabilistic model.
guide : AutoGuideModel
The variational guide.
stick_the_landing : bool, optional
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we need a better name than stick_the_landing . Also the function is supposed to return the logp and logq terms but the STL estimator is about returning only the path derivative component of the gradient.

Number of MC draws per step for gradient estimation, by default 10.
model : Model
The PyMC model to fit. If None, the model is inferred from context.
state : SVIState, optional
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer we copy what PyTorch Lightning did here and have the SVI State live in the Trainer object and be something we can pass upon initalisation.

return state

def sample_posterior(
self, draws: int, state: SVIState, model: Model | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop SVIState and move into the Trainer object

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancements New feature or request help wanted Extra attention is needed inference

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants