Referenced code: pytensor/pytensor/link/jax/dispatch/elemwise.py, lines 72 to 89 at commit d3bd1f1.
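The JAX documentation for `lax.reshape` (which `np.reshape`, i.e. `jax.numpy.reshape`, is built on) suggests that adding and removing length-1 dimensions with dedicated primitives, instead of a general reshape, may be better for further optimizations: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reshape.html#jax.lax.reshape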
Relevant part:
For inserting/removing dimensions of size 1, prefer using lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
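A rough sketch of what the suggestion could look like, assuming PyTensor's DimShuffle `new_order` notation (input axis indices, with `"x"` marking a new broadcastable axis). The `dimshuffle` helper below is hypothetical, is not the current PyTensor implementation, and skips input validation: dropped axes are removed with `lax.squeeze`, the kept axes are reordered with `lax.transpose`, and new axes are inserted with `lax.expand_dims`, so no `reshape` is needed.

```python
import jax.numpy as jnp
from jax import lax


def dimshuffle(x, new_order):
    """Hypothetical DimShuffle-like helper (sketch, no validation).

    new_order lists the input axes to keep, in output order, and uses "x"
    for new length-1 axes. Input axes not listed are dropped and are
    assumed to have length 1.
    """
    kept = [d for d in new_order if d != "x"]
    dropped = tuple(d for d in range(x.ndim) if d not in kept)
    # Remove length-1 axes explicitly instead of folding them into a reshape.
    res = lax.squeeze(x, dropped) if dropped else x
    # After squeezing, map each kept input axis to its new position.
    remaining = [d for d in range(x.ndim) if d not in dropped]
    perm = tuple(remaining.index(d) for d in kept)
    res = lax.transpose(res, perm)
    # Insert length-1 axes at the output positions where "x" appears.
    new_axes = tuple(i for i, d in enumerate(new_order) if d == "x")
    if new_axes:
        res = lax.expand_dims(res, new_axes)
    return res


x = jnp.arange(12).reshape(3, 1, 4)
y = dimshuffle(x, (2, "x", 0))  # drop axis 1, swap axes 0 and 2, add a new axis
print(y.shape)
```

Keeping squeeze, transpose and expand_dims as separate primitives is what the JAX docs quoted above recommend, since each one records which axes are removed or added, whereas a reshape only records the final shape.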