Use expand_dims / squeeze in JAX implementation of Dimshuffle #847
pytensor/pytensor/link/jax/dispatch/elemwise.py, lines 72 to 89 in d3bd1f1
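The JAX docs for lax.reshape (which jnp.reshape uses) suggest that expand_dims / squeeze may be better than reshape for further optimizations: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reshape.html#jax.lax.reshape
Relevant part: for inserting or removing dimensions of size 1, the docs recommend lax.squeeze / lax.expand_dims over reshape, because these preserve information about axis identity that may be useful for advanced transformation rules.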
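A minimal sketch of the proposal, assuming a standalone helper `dimshuffle_sketch(x, new_order)` (hypothetical, not PyTensor's actual dispatch code). `new_order` follows DimShuffle's convention: input axis indices plus `"x"` for a new length-1 axis, with any input axis omitted from `new_order` being dropped (it must have length 1). The shuffle uses `lax.transpose`, dropped axes are removed with `lax.squeeze`, and `"x"` axes are added with `lax.expand_dims`, so no generic reshape is emitted.

```python
from jax import lax
import jax.numpy as jnp


def dimshuffle_sketch(x, new_order):
    """Hypothetical DimShuffle for JAX built on transpose/squeeze/expand_dims.

    new_order: tuple of input axis indices and "x" (a new length-1 axis).
    Input axes missing from new_order are dropped and must have length 1.
    """
    kept = [d for d in new_order if d != "x"]
    dropped = [i for i in range(x.ndim) if i not in kept]

    # Dropped axes must have length 1, otherwise they cannot be squeezed away.
    for i in dropped:
        if x.shape[i] != 1:
            raise ValueError(f"cannot drop axis {i} of length {x.shape[i]}")

    # 1. Permute: kept axes first (in output order), dropped axes last.
    perm = tuple(kept + dropped)
    if perm != tuple(range(x.ndim)):
        x = lax.transpose(x, perm)

    # 2. Remove the trailing length-1 axes with squeeze instead of reshape.
    if dropped:
        x = lax.squeeze(x, tuple(range(len(kept), x.ndim)))

    # 3. Insert new length-1 axes at the "x" positions with expand_dims.
    augment = tuple(i for i, d in enumerate(new_order) if d == "x")
    if augment:
        x = lax.expand_dims(x, augment)
    return x


x = jnp.arange(6).reshape(2, 3)
print(dimshuffle_sketch(x, ("x", 1, "x", 0)).shape)  # (1, 3, 1, 2)
```

Each step only permutes, removes, or adds size-1 axes, which is the case the lax.reshape docs single out. Whether this enables further optimizations in practice is what the issue proposes; the sketch only shows that DimShuffle can be expressed without reshape.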