Description
Now that #306 is merged, there are a couple of follow-ups we should do:
- Blockwise more Ops (Everything in linalg?)
- Dispatch vectorize for more Ops
  - Alloc: Blockwise improvements #532
  - Shape: Vectorize dispatch for shape operations #454
  - Subtensor: Blockwise improvements #532
  - Arange: Blockwise improvements #532
  - ExtractDiag: Blockwise improvements #532
  - Assert
- Implement JAX and Numba dispatch
  - Use JAX vectorize: Support Blockwise in JAX backend #487
  - Use the machinery developed in Add support for random Generators in Numba backend #691
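As background for the Blockwise and vectorize items above: conceptually, Blockwise applies an Op's core computation independently to every entry along the leading batch dimensions, much like a NumPy gufunc, whereas a dedicated vectorize dispatch can express the batched result directly without that loop. The NumPy sketch below illustrates the difference, using Shape as the example. `blockwise_apply` and `vectorized_shape` are hypothetical helpers, not PyTensor API, and this is not how PyTensor implements either mechanism.

```python
import numpy as np


def blockwise_apply(core_fn, x, core_ndim):
    """Hypothetical helper: apply `core_fn` independently over the batch dims of `x`.

    The last `core_ndim` axes of `x` are treated as core dims; all leading
    axes are batch dims, mimicking Blockwise for a single-input Op.
    """
    batch_shape = x.shape[: x.ndim - core_ndim]
    # One core call per batch index, then stack and restore the batch shape
    results = [np.asarray(core_fn(x[idx])) for idx in np.ndindex(*batch_shape)]
    stacked = np.stack(results)
    return stacked.reshape(batch_shape + stacked.shape[1:])


def vectorized_shape(x, core_ndim):
    """Hypothetical dedicated rule for Shape: no loop is needed, because
    every batch entry has the same core shape."""
    batch_shape = x.shape[: x.ndim - core_ndim]
    core_shape = np.array(x.shape[x.ndim - core_ndim:])
    return np.broadcast_to(core_shape, batch_shape + core_shape.shape)


x = np.random.default_rng(0).normal(size=(2, 3, 4, 4))
print(blockwise_apply(np.linalg.det, x, core_ndim=2).shape)  # (2, 3)
```

The generic loop works for any core function, while the dedicated rule exploits knowledge of the Op to skip it entirely; NumPy's `np.vectorize(..., signature="(m,m)->()")` follows the same core/batch convention as the helper above.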
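Some Ops, like Arange, produce outputs whose length depends on input values, so their batched outputs would be "non-square". When that non-square dimension is reduced right away, the whole subgraph can be wrapped in an OpFromGraph, which can then be vectorized:

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import vectorize

i = pt.scalar("i", dtype="int64")
# arange(0, i) has a value-dependent length, but sum reduces that dimension away
y_ = pt.sum(pt.arange(0, i))
y = OpFromGraph([i], [y_])(i)

new_i = pt.vector("new_i", dtype="int64")
new_y = vectorize(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # [0, 1, 3, 6]
```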
We could explore automatically wrapping such sequences ("non-square Blockwised Op followed by a reduction of the non-square dims") in a Blockwised OpFromGraph during rewrites, to support those cases.
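To see why the reduction matters, here is a plain NumPy sketch (not PyTensor code) of what the Blockwised OpFromGraph in the example above computes: batching `arange` on its own yields rows of different lengths, while summing each row immediately gives an ordinary dense vector.

```python
import numpy as np

stops = np.array([1, 2, 3, 4])

# Batching arange over `stops` produces rows of different lengths ("non-square"),
# so there is no single dense array a Blockwise arange could return
ragged = [np.arange(0, n) for n in stops]
print([len(r) for r in ragged])  # [1, 2, 3, 4]

# Reducing the non-square dimension inside each batch entry gives a dense result
summed = np.array([np.arange(0, n).sum() for n in stops])
print(summed)  # [0 1 3 6]
```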