Currently we have an Op that calls np.tri, but we can very easily build lower triangular mask matrices with _iota:
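```python
import pytensor.tensor as pt
from pytensor.tensor.einsum import _iota

def tri(N, M, k=0):
    shape = pt.as_tensor([N, M])  # N rows, M columns, as in np.tri
    # 1 where j <= i + k (on or below the k-th diagonal); a strict > would drop that diagonal
    return ((_iota(shape, 0) + k) >= _iota(shape, 1)).astype(int)
```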
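To sanity-check the iota formulation independently of PyTensor, here is a NumPy-only sketch. `tri_via_iota` is a hypothetical helper, not part of any library: it builds row and column index grids (the 2-D analogue of `_iota(shape, 0)` and `_iota(shape, 1)`) and compares them against `np.tri` for several offsets.

```python
import numpy as np


def tri_via_iota(N, M=None, k=0, dtype=float):
    """Hypothetical NumPy sketch of the iota formulation of np.tri (not PyTensor code)."""
    if M is None:
        M = N
    # Row-index and column-index grids of shape (N, M)
    rows = np.broadcast_to(np.arange(N)[:, None], (N, M))
    cols = np.broadcast_to(np.arange(M)[None, :], (N, M))
    # Entry (i, j) is 1 on or below the k-th diagonal, i.e. where i + k >= j
    return (rows + k >= cols).astype(dtype)


# Compare against NumPy's own implementation for a range of offsets
for k in (-2, -1, 0, 1, 3):
    assert np.array_equal(tri_via_iota(4, 5, k), np.tri(4, 5, k))
```

Note that the comparison must be `>=` rather than `>`: a strict comparison drops the k-th diagonal itself, which `np.tri` includes.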
This is what JAX does. The benefit of doing it this way is that we automatically get an Op that can be dispatched to Numba (Numba supports np.tri, but only under specific circumstances; I tried a naive dispatch and it didn't work) and to PyTorch (#821 asks for Tri, so this would check off that box).
I suggest we wrap this in a dummy OpFromGraph, as we do for Kron and AllocDiag, so that the dprints are nicer. We could also override the L_op if we want. The current tri uses grad_undefined, so we could keep that if it's correct, or just rely on autodiff, since the proposed _iota graph should be differentiable.
@jessegrabowski New contributor here. I created a PR based on my understanding of the issue. It probably has a few errors. I was unable to run pytest locally due to some circular import issue. Waiting for feedback!