Implement Dimshuffle in jax using expand_dims/squeeze #987
base: main
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #987      +/-   ##
==========================================
- Coverage   81.74%   81.74%   -0.01%
==========================================
  Files         183      183
  Lines       47724    47722       -2
  Branches    11616    11615       -1
==========================================
- Hits        39011    39009       -2
  Misses       6520     6520
  Partials     2193     2193
res = jnp.reshape(res, shape)
res = jax.lax.expand_dims(res, op.augment)
res = jax.lax.squeeze(res, op.drop)
I think that after the transpose the dropped dimensions are all on the right, not at the positions in `op.drop`? So it should be something like `squeeze(res, tuple(-(np.arange(len(op.drop) + 1))))`? Also, I think the `squeeze` should be done before the `expand_dims`.
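To make the axis bookkeeping concrete, here is a small sketch with a hypothetical input of shape `(1, 3, 1, 4)` whose axes 0 and 2 are dropped and whose remaining axes are swapped. The names `shuffle`, `drop` and `transposition` are assumptions modeled on the attributes mentioned in this thread; they are not taken from the PR's code.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(1, 3, 1, 4)

shuffle = (3, 1)                # kept input axes, in output order (assumed)
drop = (0, 2)                   # length-1 input axes to remove (assumed)
transposition = shuffle + drop  # kept axes first, dropped axes last

res = jnp.transpose(x, transposition)
print(res.shape)  # (4, 3, 1, 1): the dropped axes now sit on the right

# Squeezing the trailing len(drop) axes removes exactly the dropped dims.
trailing = tuple(range(res.ndim - len(drop), res.ndim))  # (2, 3)
out = jax.lax.squeeze(res, trailing)
print(out.shape)  # (4, 3)

# The original input positions in `drop` no longer match after the transpose:
# axis 0 of `res` has length 4, so squeezing it is rejected.
try:
    jax.lax.squeeze(res, drop)
    squeeze_with_drop_failed = False
except (ValueError, TypeError):
    squeeze_with_drop_failed = True
print(squeeze_with_drop_failed)
```

In this example the failure is loud, but if the axes in `drop` happened to have length 1 after the transpose, the wrong axes would be removed silently.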
I suspect we need more tests of DimShuffle in the JAX backend that combine multiple expand_dims / transposition / drop, to catch accidents like the one in this refactor (unless I'm wrong and the logic is correct, but I suspect not).
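A rough sketch of what such combined tests could look like, assuming a standalone helper rather than the backend's real dispatch function. `jax_dimshuffle`, `numpy_reference` and the `"x"` marker for a new axis are all hypothetical names chosen for illustration; the candidate applies transpose, then squeeze of the trailing axes, then `expand_dims`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_dimshuffle(x, new_order):
    """Hypothetical DimShuffle: ints pick input axes, "x" inserts a length-1 axis."""
    shuffle = tuple(i for i in new_order if i != "x")
    drop = tuple(i for i in range(x.ndim) if i not in shuffle)
    augment = tuple(pos for pos, i in enumerate(new_order) if i == "x")
    res = jnp.transpose(x, shuffle + drop)
    # Dropped axes are trailing after the transpose, so squeeze them first.
    res = jax.lax.squeeze(res, tuple(range(len(shuffle), res.ndim)))
    # Then insert the new axes at their final output positions.
    return jax.lax.expand_dims(res, augment)


def numpy_reference(x, new_order):
    """Independent NumPy reference built from reshape and np.newaxis indexing."""
    kept = [i for i in new_order if i != "x"]
    dropped = [i for i in range(x.ndim) if i not in kept]
    res = np.transpose(x, kept + dropped).reshape([x.shape[i] for i in kept])
    index = tuple(np.newaxis if i == "x" else slice(None) for i in new_order)
    return res[index]


x_np = np.arange(12.0).reshape(1, 3, 1, 4)
patterns = [
    (0, 1, 2, 3),            # identity
    (3, 2, 1, 0),            # pure transposition, nothing dropped
    (3, 1),                  # transposition plus two drops
    ("x", 3, "x", 1),        # drops, transposition and two new axes
    (1, 3, 0, "x"),          # one drop, one new axis at the end
    ("x", "x", 1, 3, "x"),   # several new axes around kept ones
]
for new_order in patterns:
    got = np.asarray(jax_dimshuffle(jnp.asarray(x_np), new_order))
    expected = numpy_reference(x_np, new_order)
    assert got.shape == expected.shape, new_order
    np.testing.assert_array_equal(got, expected)
```

Patterns that mix a drop, a transposition and a new axis in one call are the ones most likely to expose an ordering mistake between `squeeze` and `expand_dims`.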
Also, unrelated: the `op.inplace` logic below is not needed in JAX, which never performs in-place operations.
In the case of a transpose, `op.drop` would be an empty list. But `tuple(-(np.arange(len(op.drop) + 1)))` would be non-empty, and if `squeeze` is done before `expand_dims`, the logic would fail in this case.
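The edge case can be checked directly. The sketch below evaluates one reading of the expression from this thread (parentheses balanced at the end) and compares it with a hypothetical alternative, `trailing_axes`, which selects the last `len(drop)` axes and is empty when nothing is dropped. Both helper names are made up for illustration.

```python
import numpy as np


def literal_axes(drop):
    # The expression as written in the thread, with parentheses balanced.
    return tuple(int(a) for a in -(np.arange(len(drop) + 1)))


def trailing_axes(drop):
    # Hypothetical alternative: the last len(drop) axes, empty if drop is empty.
    return tuple(range(-len(drop), 0))


print(literal_axes(()))        # (0,)  non-empty even though nothing is dropped
print(trailing_axes(()))       # ()
print(literal_axes((0, 2)))    # (0, -1, -2)  includes axis 0
print(trailing_axes((0, 2)))   # (-2, -1)
```

With an empty `drop`, the literal expression still asks to squeeze axis 0, which is the failure described above; the alternative yields no axes at all for a pure transpose.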
Sure, but an alternative is still needed. The current implementation in this PR is wrong, AFAICT.
Description
Implement DimShuffle using expand_dims/squeeze
Related Issue
`expand_dims`/`squeeze` in JAX implementation of `Dimshuffle` #847
Checklist
Type of change