Implement all Ops in PyTorch (help welcome!) #821
Hi there @ricardoV94, I'm attending the hackathon at PyData London. I'd like to work on Softmax.
Sure!
Thanks @ricardoV94. I have been doing some reading around the docs and noticed that #764 has the initial setup. I believe it contains the groundwork for adding other Ops, so I'd be following that PR, and the Softmax Ops will likely use some code from it. Am I thinking about this the right way?
#764 is merged
Hello @ricardoV94, I would like to work on Reshape
Go ahead. We'll link and lock the Op when you open a PR
Hi @ricardoV94, I have opened a PR for the Softmax Ops. I see it has been grouped with LogSoftmax and Grads, so I can update the PR to include them.
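For reference, here is a minimal pure-Python sketch of the numerics that Softmax and LogSoftmax implementations must reproduce (the actual backend code would dispatch to the equivalent torch functions); the function names here are just illustrative:

```python
import math

def softmax(xs):
    # Shift by the max so exp() cannot overflow; the result is unchanged
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(xs):
    # log_softmax(x) = x - logsumexp(x); more stable than log(softmax(x))
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]
```

The max-shift trick matters for the grads as well: computing log(softmax(x)) directly loses precision for large-magnitude inputs, which is why LogSoftmax exists as a separate Op.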
Hi @ricardoV94, I will work on the
If someone wants to look through the codebase and populate the list of Ops above that would also be very helpful :) |
@ricardoV94 Does something like this work? There are corresponding torch functions/methods attached to each op.
I could help with the remaining |
Thanks @twaclaw, feel free to open a PR |
I will have a look at @ricardoV94, regarding
@twaclaw in general we want to support exactly the same functionality as the original Op. When that is not possible or too complicated, raising NotImplementedError is reasonable. Regarding JAX, we probably cannot compile (JIT) any function that has unique in it, because JAX can't handle dynamic shapes. So it's a bit moot whether we say we support axis or not, although the NotImplementedError could be removed and we could just dispatch to
I'll be working on the indexing Ops now |
@ricardoV94, regarding the JAX implementation of
I can't imagine many cases where I would know the size of the unique elements but not what/where they were. If I knew what/where they are, I would just select them instead of using unique. More importantly, we don't have size in our Unique Op, and I don't think it makes sense to add it for this edge case. A more general approach would be to be clever about what can and cannot be jitted; I think there's an issue open for that already. For now we can probably just remove the implementation and let it raise NotImplementedError if, as you say, it's broken anyway.
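The dynamic-shape problem with unique can be seen in a tiny stdlib-only sketch (the helper below is illustrative, not the backend implementation):

```python
def unique(values):
    # The output length depends on the *values*, not on the input's shape.
    # A JIT compiler that requires static shapes therefore cannot trace
    # through this without a user-provided output size.
    seen = []
    for v in values:
        if v not in seen:
            seen.append(v)
    return seen

# Two inputs of identical shape produce outputs of different shape:
assert len(unique([1, 1, 1])) == 1
assert len(unique([1, 2, 3])) == 3
```

This is exactly why a static-shape JIT would need the number of unique elements up front, and why knowing that number in advance largely defeats the purpose of calling unique at all.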
You are right,
I can take a look at the linear algebra ops. |
Seems like Reshape is becoming more and more relevant, if anyone wants to tackle it
Is someone working on this? |
Not yet I think. You can go ahead |
Is the checklist at the top up to date on what else is needed? |
More or less up to date, except that linalg and indexing are being worked on
If someone is interested, we need to check whether we can bridge nicely between the PyTensor and Torch random number generator APIs. We have added a recent documentation page explaining how random variables work in PyTensor: #928. As a reminder, we're targeting torch.compile functionality, in case that matters.
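The core difference to bridge: PyTensor threads RNG state through the graph explicitly, while torch draws default to a global generator. A stdlib-only sketch of the explicit-state pattern, using random.Random as a stand-in for the real generator objects (the function name is hypothetical):

```python
import random

def normal_draw(rng, loc=0.0, scale=1.0):
    # The generator is passed in explicitly: PyTensor random variables
    # receive their RNG as a graph input, rather than relying on global
    # state the way torch's default generator does.
    return rng.gauss(loc, scale)

rng = random.Random(2024)  # seeded generator, a stand-in for a shared RNG object
first = normal_draw(rng)
second = normal_draw(rng)  # the state was advanced by the first draw
```

Any bridge to torch would presumably need to map this explicit-state pattern onto torch.Generator objects in a way that survives torch.compile tracing, which is the part that needs checking.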
@ricardoV94 I'll take a stab after I finish some of the operators in #939. I need to build a bit more familiarity with the PyTensor code first
Coming from PyMC, adding a sparse solve would be useful I believe... |
This issue is not very relevant for that request: we first need a sparse solve in PyTensor itself before we can add it to the PyTorch backend, and we haven't done anything with sparse stuff in the PyTorch backend yet.
Description
If you want to help implementing some of these Ops, just leave a comment below saying which ones you are interested in. We'll give you some time to work on it (and then put it back up for grabs).
See the documentation for How to implement PyTorch Ops and tests: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html
Example PR: #836
See #821 (comment) for suggestions on equivalent torch functions
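The backend converts Ops via a dispatch-and-register pattern; the sketch below mimics it with functools.singledispatch and toy stand-in classes (Op, Softmax, Reshape, and pytorch_funcify here are simplified placeholders, not the real PyTensor signatures):

```python
from functools import singledispatch

# Toy stand-ins for PyTensor Op classes, for illustration only
class Op:
    pass

class Softmax(Op):
    def __init__(self, axis):
        self.axis = axis

class Reshape(Op):
    pass

@singledispatch
def pytorch_funcify(op):
    # Ops without a registered conversion raise, mirroring the backend
    raise NotImplementedError(f"No PyTorch conversion for {type(op).__name__}")

@pytorch_funcify.register(Softmax)
def _(op):
    axis = op.axis

    def softmax(x):
        # In the real backend this would call torch.softmax(x, dim=axis)
        return ("softmax", axis, x)

    return softmax
```

Implementing a new Op then amounts to registering one more conversion function that closes over the Op's parameters and returns a callable over tensors, plus tests comparing it against the reference implementation.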
- Tensor creation Ops
- Shape Ops
- Math Ops
- Indexing Ops
- Branching Ops
- Linalg Ops
- Sparse Ops
- RandomVariable Ops
If you need an Op that's not in this list, comment below and we'll add it!