Add specialization rewrite for solve with batched b #482
Conversation
The vectorization of shape operations (mostly reshape) from #454 is also critical for the discourse-related performance issue.
Force-pushed from 22f01cc to 2eae804
Codecov Report

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #482      +/-   ##
==========================================
+ Coverage   80.73%   80.75%   +0.01%
==========================================
  Files         160      161       +1
  Lines       46007    46058      +51
  Branches    11245    11258      +13
==========================================
+ Hits        37146    37195      +49
  Misses       6636     6636
- Partials     2225     2227       +2
```
```python
matrix_b_solve = Blockwise(new_core_op)

# Apply the rewrite
new_solve = _T(matrix_b_solve(a, _T(b)))
```
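For context, the transformation above exploits the fact that a batch of vector solves against a single a is equivalent to one matrix solve on the transposed stack. A minimal NumPy sketch of that equivalence (variable names here are illustrative, not the PR's code):

```python
import numpy as np

rng = np.random.default_rng(42)
n, batch = 5, 1000
A = rng.normal(size=(n, n))
b = rng.normal(size=(batch, n))  # batched vector b (b_ndim=1)

# Naive Blockwise evaluation: one vector solve per batch entry.
looped = np.stack([np.linalg.solve(A, b_i) for b_i in b])

# Rewritten form: a single matrix solve wrapped in transposes,
# mirroring new_solve = _T(matrix_b_solve(a, _T(b))) above.
rewritten = np.linalg.solve(A, b.T).T

assert np.allclose(looped, rewritten)
```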
Actually, we could use any batched b dimension here when there are multiple of them. We might choose the largest one to reduce outer looping (see the sketch below).
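A hedged NumPy sketch of what that choice means: with two batch dimensions, folding the larger one into the matrix-solve columns leaves only the smaller one to loop over (shapes and names here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k1, k2 = 4, 3, 100          # k2 is the larger batch dimension
A = rng.normal(size=(n, n))
b = rng.normal(size=(k1, k2, n))

# Loop only over the small axis k1; each step solves all k2
# right-hand-side vectors in a single matrix solve.
x = np.stack([np.linalg.solve(A, b[i].T).T for i in range(k1)])

assert np.allclose(A @ x[0].T, b[0].T)
```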
```python
@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
    ]


@register_stabilize
```
I want to make this a specialize-only rewrite, but for now it's registered in stabilize because PyMC applies those rewrites before calling grad; otherwise we would still end up with messy, unoptimized Blockwise graphs.
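For readers unfamiliar with the registration mechanics, a minimal sketch of the double registration being described, assuming PyTensor's rewrite-registration API (the rewrite name and empty body are placeholders, not the PR's actual code):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize

# Registering under stabilize as well means the rewrite runs in the
# passes PyMC applies before calling grad (see the comment above).
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def batched_vector_b_solve_rewrite(fgraph, node):
    """Placeholder body; the real rewrite returns the transposed matrix solve."""
    ...
```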
Force-pushed from 8d254f9 to 4f4379c
Related to pymc-devs/pymc#6993 and https://discourse.pymc.io/t/version-dependant-slowing-down-of-gaussian-mixture-sampling-in-ubuntu-20-04/13219
This PR adds a rewrite that optimizes solve when there is a batched vector b (b_ndim=1) and a single a.
This form was previously used for the logp of MvNormal in PyMC and is potentially much faster when the batched dimensions of b are large. More importantly, the grad will look better.
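As a concrete illustration, this is roughly the pattern the rewrite targets, assuming PyTensor's solve with an explicit b_ndim (a hedged sketch, not code from the PR):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

# A single system matrix and a batch of right-hand-side vectors,
# the MvNormal-logp-style pattern described above.
A = pt.matrix("A")         # shape (n, n)
b = pt.matrix("b")         # shape (batch, n): batched vector b
x = solve(A, b, b_ndim=1)  # Blockwise vector solve over the batch dim

# After the rewrite, the compiled graph should perform one matrix
# solve on b.T instead of looping a vector solve over the batch.
fn = pytensor.function([A, b], x)
```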
I have a couple of other PRs that improve the Blockwise grads for the truly new supported cases, but that's slower work.
PS: Obviously we should also implement the JAX Ops (#430), but this looks like a nice optimization, especially if it solves the regression described in the discourse issue above.