Add Ops for LU Factorization #1218
Conversation
def lu(
    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
You can allow overwrite_a in the signature but just ignore it, like JAX does.
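A minimal sketch of what that could look like (the keyword placement is illustrative; the point is only that overwrite_a is accepted for scipy API compatibility and then ignored):

def lu(
    a: TensorLike,
    permute_l=False,
    overwrite_a=False,  # accepted for scipy API compatibility, intentionally ignored
    check_finite=True,
    p_indices=False,
):
    # overwrite_a is not forwarded anywhere: PyTensor graphs are functional,
    # so destructive updates are left to graph rewrites, not the user-facing helper.
    ...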
    compare_numba_and_py([A], out, test_inputs=[A_val], inplace=True)

else:
    # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
Just pass eval_obj_mode=False.
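For example (assuming compare_numba_and_py accepts an eval_obj_mode flag, per the suggestion above):

# Skips the object-mode (non-jitted) evaluation that raises
# NotImplementedError for BlockwiseWithCoreShape.
compare_numba_and_py(
    [A], out, test_inputs=[A_val], inplace=True, eval_obj_mode=False
)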
def __init__(self, *, overwrite_a=False, check_finite=True):
    self.overwrite_a = overwrite_a
    self.check_finite = check_finite
    self.gufunc_signature = "(m,m)->(m,m),(m)"
Put the signature outside of __init__, since it's constant.
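Something along these lines (the class name here is illustrative, not necessarily the Op defined in this PR):

class LUFactor(Op):
    # The signature is the same for every instance, so it can live on the class
    # rather than being rebuilt in __init__.
    gufunc_signature = "(m,m)->(m,m),(m)"

    def __init__(self, *, overwrite_a=False, check_finite=True):
        self.overwrite_a = overwrite_a
        self.check_finite = check_finite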
f = pytensor.function([A], out, mode="NUMBA")

if len(shape) == 2:
    compare_numba_and_py([A], out, [A_val], inplace=True)
Same here: you can use compare_numba_and_py.
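Roughly (assuming the same eval_obj_mode workaround as above, so the helper can replace the manual compilation in both branches):

# compare_numba_and_py compiles the graph with the Numba backend itself,
# so the explicit pytensor.function call is no longer needed.
compare_numba_and_py([A], out, [A_val], inplace=True, eval_obj_mode=False)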
Any benchmarks on solve written with these Ops?
Working on that next, just ironing out some bugs in the
Description
This PR will add the following Ops:
As well as dispatches for numba/jax (and maybe torch, though help is welcome there).
The reason for wanting these is that it will make the gradients of solve faster. I think this is a major reason why JAX has faster gradients than us (at least when solve is involved). They route everything to lu_solve(lu_factor(A), b) and reuse lu_factor(A) in the backward pass (a sketch of this pattern is included at the bottom of this description).
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1218.org.readthedocs.build/en/1218/
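The reuse pattern described above is easy to see with plain SciPy. This sketch uses scipy.linalg directly rather than the new Ops, purely to illustrate why reusing the factorization in the backward pass avoids a second O(n^3) factorization:

import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 100))
b = rng.standard_normal(100)

# Forward pass: factor once, then solve.
lu_and_piv = lu_factor(A)      # O(n^3), done once
x = lu_solve(lu_and_piv, b)    # O(n^2)

# Backward pass of solve(A, b): the gradient w.r.t. b is solve(A.T, g),
# which can reuse the same factors via trans=1 instead of factorizing again.
g = rng.standard_normal(100)               # upstream gradient w.r.t. x
b_bar = lu_solve(lu_and_piv, g, trans=1)   # reuses the LU factors
A_bar = -np.outer(b_bar, x)                # gradient w.r.t. A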