Replace our Tri op with an OpFromGraph #1265

Open
jessegrabowski opened this issue Mar 5, 2025 · 1 comment

jessegrabowski commented Mar 5, 2025

Description

Currently we have an Op that calls np.tri, but we can very easily build lower triangular mask matrices with _iota:

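from pytensor.tensor.einsum import _iota
def tri(N, M, k=0):
    # Assumes _iota(shape, axis) counts along `axis`; entry (i, j) is 1 where j <= i + k, as in np.tri
    shape = (N, M)
    return ((_iota(shape, 0) + k) >= _iota(shape, 1)).astype(int)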

This is what JAX does. The benefit of this approach is that we automatically get a dispatchable Op for Numba (Numba supports np.tri, but only under specific circumstances; I tried a naive dispatch and it didn't work) and PyTorch (#821 asks for Tri, so this would check off that box).

I suggest we wrap this in a dummy OpFromGraph, like we do for Kron and AllocDiag, so that the dprints are nicer. We could also overload the L_op if we want. The current tri has grad_undefined, so we could keep that if it's correct, or just keep the autodiff solution, since the proposed _iota-based function should be differentiable.
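
As a sanity check on the mask construction described above, here is a minimal NumPy-only sketch that builds the mask by comparing a row-index grid with a column-index grid and checks it against np.tri. The helpers `iota` and `tri_from_iota` are hypothetical names written for this illustration; `iota` is a stand-in for an iota-style helper, not PyTensor's `_iota`.

```python
import numpy as np


def iota(shape, axis):
    """Hypothetical stand-in for an iota helper: counts 0, 1, 2, ... along `axis`."""
    index = np.arange(shape[axis])
    # Reshape so the counter varies only along `axis`, then broadcast to the full shape.
    expand = [1] * len(shape)
    expand[axis] = shape[axis]
    return np.broadcast_to(index.reshape(expand), shape)


def tri_from_iota(N, M=None, k=0, dtype=float):
    """Lower-triangular mask with the np.tri convention: 1 where j <= i + k."""
    M = N if M is None else M
    rows = iota((N, M), 0)  # row index i at every position
    cols = iota((N, M), 1)  # column index j at every position
    return (rows + k >= cols).astype(dtype)


# Compare against np.tri for square, wide, and tall shapes and several offsets.
for N, M in [(3, 3), (2, 5), (5, 2)]:
    for k in range(-3, 4):
        np.testing.assert_array_equal(tri_from_iota(N, M, k), np.tri(N, M, k))
```

The comparison has to be `>=` to match np.tri: a strict `>` would exclude the main diagonal at k = 0.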

Nimish-4 commented Mar 8, 2025

@jessegrabowski New contributor here. I created a PR based on my understanding of the issue. It probably has a few errors. I was unable to run pytest locally due to some circular import issue. Waiting for feedback!
