Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Dimshuffle in jax using expand_dims/squeeze #987

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

HarshvirSandhu
Copy link
Contributor

Description

Implement DimShuffle using expand_dims/squeeze

Related Issue

Checklist

Type of change

  • Maintenance

Copy link

codecov bot commented Sep 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.74%. Comparing base (b66d859) to head (322f6ee).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #987      +/-   ##
==========================================
- Coverage   81.74%   81.74%   -0.01%     
==========================================
  Files         183      183              
  Lines       47724    47722       -2     
  Branches    11616    11615       -1     
==========================================
- Hits        39011    39009       -2     
  Misses       6520     6520              
  Partials     2193     2193              
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/elemwise.py 81.15% <100.00%> (-0.54%) ⬇️


res = jnp.reshape(res, shape)
res = jax.lax.expand_dims(res, op.augment)
res = jax.lax.squeeze(res, op.drop)
Copy link
Member

@ricardoV94 ricardoV94 Sep 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think after the transpose the dropped dimension are all on the right, not in op.drop? So should be something like squeeze(res, tuple(-(np.arange(len(op.drop) + 1)) ? Also I think the squeeze should be done before the expand_dims.

I suspect we need more tests of DimShuffle in JAX backend that test for multiple expand_dims / transposition / drop, to catch accidents like in this refactor (unless I'm wrong and the logic is correct, but I suspect not).

Also not related, but the op.inplace logic below is not needed in JAX, which never does inplace operations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of a transpose, op.drop would be an empty list. But tuple(-(np.arange(len(op.drop) + 1)) would be non empty and if squeeze is done before expand_dims, the logic would fail in this case

Copy link
Member

@ricardoV94 ricardoV94 Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but an alternative is still needed. The current implementation in this PR is wrong AFAICT

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use expand_dims / squeeze in JAX implementation of Dimshuffle
2 participants