
Add specialization rewrite for solve with batched b #482

Merged: 3 commits merged into pymc-devs:main from solve_opt on Nov 11, 2023

Conversation

ricardoV94 (Member) commented on Nov 7, 2023

Related to pymc-devs/pymc#6993 and https://discourse.pymc.io/t/version-dependant-slowing-down-of-gaussian-mixture-sampling-in-ubuntu-20-04/13219

This PR adds a rewrite that optimizes solve when there is a batched vector b (b_ndim=1) and a single a.
This form was previously used for the logp of MvNormal in PyMC and is potentially much faster when the batched dimensions of b are large. More importantly, the grad will look better.
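
Below is a minimal NumPy sketch (not the PR's actual rewrite code; shapes are made up for illustration) of the equivalence the rewrite exploits: a stack of vector solves against a single a is the same as one matrix solve against the transposed stack.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))    # one core matrix
b = rng.normal(size=(100, 3))  # 100 batched right-hand-side vectors

# Batched form: one vector solve per row of b (what Blockwise would loop over)
batched = np.stack([np.linalg.solve(a, b_i) for b_i in b])

# Rewritten form: a single matrix solve on b transposed, then transpose back
rewritten = np.linalg.solve(a, b.T).T

np.testing.assert_allclose(batched, rewritten)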

I have a couple of other PRs that improve the blockwise grads for the truly new supported cases, but that is slower work.

PS: Obviously we should also implement the JAX Ops (#430), but this looks like a nice optimization, especially if it solves the regression described in the Discourse thread above.

ricardoV94 (Member, Author) commented on Nov 7, 2023

The vectorization of shape operations (mostly reshape) from #454 is also critical for the performance issue reported on Discourse.

@ricardoV94 force-pushed the solve_opt branch 2 times, most recently from 22f01cc to 2eae804, on November 7, 2023 at 16:11
codecov-commenter commented on Nov 7, 2023

Codecov Report

Merging #482 (4f4379c) into main (893dc18) will increase coverage by 0.01%.
Report is 2 commits behind head on main.
The diff coverage is 94.33%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #482      +/-   ##
==========================================
+ Coverage   80.73%   80.75%   +0.01%     
==========================================
  Files         160      161       +1     
  Lines       46007    46058      +51     
  Branches    11245    11258      +13     
==========================================
+ Hits        37146    37195      +49     
  Misses       6636     6636              
- Partials     2225     2227       +2     
Files                                      Coverage              Δ
pytensor/link/jax/dispatch/__init__.py     100.00% <100.00%>     (ø)
pytensor/tensor/blockwise.py                79.89% <100.00%>     (+0.62%) ⬆️
pytensor/tensor/rewriting/linalg.py         81.25% <96.29%>      (+3.93%) ⬆️
pytensor/link/jax/dispatch/blockwise.py     89.47% <89.47%>      (ø)

@ricardoV94 ricardoV94 marked this pull request as ready for review November 7, 2023 19:13
# Wrap the new matrix-b core solve op in a Blockwise so it still broadcasts
matrix_b_solve = Blockwise(new_core_op)

# Apply the rewrite: move the batched vectors of b into the columns of a
# single matrix solve, then transpose the result back
new_solve = _T(matrix_b_solve(a, _T(b)))
ricardoV94 (Member, Author) commented:

Actually, we could move any batched dimension of b here when there are multiple of them. We may want to choose the largest one to reduce outer looping.
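
A small NumPy sketch of that idea (shapes are made up for illustration): with two batch dimensions on b, folding the larger one into the matrix solve leaves fewer iterations for the outer Blockwise loop.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
b = rng.normal(size=(4, 100, 3))  # two batch dimensions: 4 and 100

# Reference: one vector solve per (i, j) pair
ref = np.stack([[np.linalg.solve(a, b[i, j]) for j in range(100)] for i in range(4)])

# Fold the larger batch dimension (100) into the matrix solve; only 4 outer iterations remain
folded = np.stack([np.linalg.solve(a, b[i].T).T for i in range(4)])

np.testing.assert_allclose(ref, folded)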

@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
]


@register_stabilize
ricardoV94 (Member, Author) commented:

I want to make this a specialize-only rewrite, but for now it's registered in stabilize because PyMC includes the stabilize rewrites before calling grad, and otherwise we would still end up with messy, unoptimized blockwise graphs.
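
For reference, a hypothetical sketch (the function name and predicate are illustrative, not the PR's actual code) of what a specialize-only registration could look like:

# Hypothetical sketch of a specialize-only registration; the rewrite body is elided.
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.slinalg import Solve


@register_specialize
@node_rewriter([Blockwise])
def local_batched_vector_b_solve_sketch(fgraph, node):
    # Only act on a Blockwise that wraps a Solve with a vector core b
    core_op = node.op.core_op
    if not isinstance(core_op, Solve) or core_op.b_ndim != 1:
        return None
    # ... build the matrix-b solve with transposes and return the replacement ...
    return None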

@ricardoV94 ricardoV94 merged commit 0f802ab into pymc-devs:main Nov 11, 2023
53 checks passed