
Add Ops for LU Factorization #1218

Draft: jessegrabowski wants to merge 7 commits into main from LU-factorization
Conversation

@jessegrabowski (Member) commented Feb 18, 2025

Description

This PR will add the following Ops:

  • lu
  • lu_factor
  • lu_solve

It also adds dispatches for Numba and JAX (and maybe PyTorch, though help is welcome there).
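These follow the scipy.linalg functions of the same names. For orientation, a quick NumPy/SciPy sketch of their semantics (the matrix values are illustrative):

```python
import numpy as np
from scipy.linalg import lu, lu_factor, lu_solve

A = np.array([[4.0, 3.0], [6.0, 3.0]])
b = np.array([1.0, 2.0])

# lu: explicit factors such that A = P @ L @ U
P, L, U = lu(A)
np.testing.assert_allclose(P @ L @ U, A)

# lu_factor / lu_solve: packed LAPACK representation that is
# factored once and can be reused for many right-hand sides
lu_and_piv = lu_factor(A)
x = lu_solve(lu_and_piv, b)
np.testing.assert_allclose(A @ x, b)
```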

The motivation is to make the gradients of solve faster. I think this is a major reason why JAX has faster gradients than we do (at least when solve is involved): JAX routes everything through lu_solve(lu_factor(A), b) and reuses lu_factor(A) in the backward pass.
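To illustrate the pattern being described, here is a minimal NumPy/SciPy sketch (the matrix values and the upstream gradient g are made up):

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

A = np.array([[4.0, 3.0], [6.0, 3.0]])
b = np.array([1.0, 2.0])

# Factor A once; this is the expensive O(n^3) step
lu_and_piv = lu_factor(A)

# Forward pass: solve A @ x = b using the cached factors
x = lu_solve(lu_and_piv, b)

# Backward pass: the gradient of solve w.r.t. b is solve(A.T, g),
# and trans=1 reuses the same factors to solve A.T @ y = g
g = np.ones_like(x)  # stand-in for the upstream gradient
b_bar = lu_solve(lu_and_piv, g, trans=1)
```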

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1218.org.readthedocs.build/en/1218/

@jessegrabowski force-pushed the LU-factorization branch 3 times, most recently from b1f8c9d to 2a5b361 on February 20, 2025 06:05


Review comment from a Member on:

```python
def lu(
    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
```

You can allow overwrite_a in the signature but just ignore it, like JAX does.
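A sketch of what that suggestion might look like (a guess based on the comment, not the PR's actual code):

```python
def lu(
    a: TensorLike,
    permute_l=False,
    check_finite=True,
    p_indices=False,
    overwrite_a=False,
):
    # overwrite_a is accepted for scipy API compatibility but deliberately
    # ignored, as in JAX: the graph owns its buffers, so user-requested
    # in-place overwriting is not honored here.
    ...
```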

Review comment from a Member on:

```python
    compare_numba_and_py([A], out, test_inputs=[A_val], inplace=True)

else:
    # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
```

Just pass eval_obj_mode=False.
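That is, something like the following for the failing branch (assuming the helper's eval_obj_mode keyword works as the comment implies):

```python
# Skip the non-jitted (object mode) evaluation that raises
# NotImplementedError for BlockwiseWithCoreShape:
compare_numba_and_py(
    [A], out, test_inputs=[A_val], inplace=True, eval_obj_mode=False
)
```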

Review comment from a Member on:

```python
def __init__(self, *, overwrite_a=False, check_finite=True):
    self.overwrite_a = overwrite_a
    self.check_finite = check_finite
    self.gufunc_signature = "(m,m)->(m,m),(m)"
```

Put the signature outside of __init__ since it's constant.
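A sketch of the suggested change (the class name LUFactor is assumed from the gufunc signature, which matches scipy's lu_factor):

```python
class LUFactor(Op):
    # Constant for all instances, so define it once at the class level
    gufunc_signature = "(m,m)->(m,m),(m)"

    def __init__(self, *, overwrite_a=False, check_finite=True):
        self.overwrite_a = overwrite_a
        self.check_finite = check_finite
```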

Review comment from @ricardoV94 (Member), Feb 28, 2025, on:

```python
f = pytensor.function([A], out, mode="NUMBA")

if len(shape) == 2:
    compare_numba_and_py([A], out, [A_val], inplace=True)
```

Same here: you can use compare_numba_and_py.

@ricardoV94 (Member) commented Feb 28, 2025

Any benchmarks on solve written with these Ops?

@jessegrabowski (Member, Author) commented

Working on that next; just ironing out some bugs in the lu_solve Op (which is what all this is building towards).
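For reference, a minimal sketch of how such a solve-gradient benchmark might look (the shapes and repetition count are arbitrary choices):

```python
import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")
b = pt.vector("b")
x = solve(A, b)

# Gradient of a scalar loss w.r.t. A; with the new Ops this should
# reuse the LU factors from the forward solve
g = pytensor.grad(x.sum(), A)
f = pytensor.function([A, b], g)

rng = np.random.default_rng(0)
A_val = rng.normal(size=(200, 200))
b_val = rng.normal(size=200)

print(timeit.timeit(lambda: f(A_val, b_val), number=100))
```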
