Description
When output shapes depend on input values, Blockwise Ops are not necessarily valid at runtime. For example, the following graph is not supported by PyTensor at runtime, because it would require support for ragged arrays in the intermediate Blockwise(Arange):
```python
import pytensor.tensor as pt
from pytensor.graph import vectorize

i = pt.scalar("i", dtype=int)
y = pt.sum(pt.arange(0, i))  # the length of arange depends on the *value* of i

new_i = pt.vector("new_i", dtype=int)
new_y = vectorize(y, {i: new_i})  # Arange becomes Blockwise(Arange)
new_y.eval({new_i: [1, 2, 3, 4]})  # ValueError
```
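To see why this fails, here is a NumPy-only sketch (not PyTensor's implementation) of what the intermediate Blockwise(Arange) would have to return: one arange per batch entry. With stops 1, 2, 3, 4 the rows have different lengths, so they cannot be stacked into one rectangular array, which is the ragged intermediate PyTensor has no type for. The helper name `batched_arange` is hypothetical.

```python
import numpy as np


def batched_arange(stops):
    """Apply np.arange(0, stop) to each batch entry separately (hypothetical helper)."""
    return [np.arange(0, stop) for stop in stops]


rows = batched_arange([1, 2, 3, 4])
print([row.shape for row in rows])  # [(1,), (2,), (3,), (4,)]: a different length per entry

try:
    # A dense tensor needs every row to have the same length, so stacking fails.
    np.stack(rows)
except ValueError as err:
    print("ragged intermediate:", err)
```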
However, if we wrap the Arange + Sum in an OpFromGraph, that subgraph becomes a valid Blockwise, and PyTensor is happy to evaluate it:
```python
import pytensor.tensor as pt
from pytensor.graph import vectorize
from pytensor.compile.builders import OpFromGraph

i = pt.scalar("i", dtype=int)
y_ = pt.sum(pt.arange(0, i))
# Wrap Arange + Sum in a single Op whose output is a scalar for any value of i
y = OpFromGraph([i], [y_])(i)

new_i = pt.vector("new_i", dtype=int)
new_y = vectorize(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # [0, 1, 3, 6]
```
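Conceptually, once Arange and Sum live inside one core Op, a Blockwise of that Op only has to loop the core function over the batch dimension and stack the outputs, which are now all scalars. A NumPy sketch of that loop, assuming a single batch dimension; `core_fn` and `naive_blockwise` are hypothetical names, and PyTensor's actual Blockwise implementation differs:

```python
import numpy as np


def core_fn(i):
    """Core computation wrapped by the OpFromGraph: sum(arange(0, i))."""
    # The ragged arange output is created and reduced inside this call,
    # so it never has to be stored alongside the other batch entries.
    return np.sum(np.arange(0, i))


def naive_blockwise(fn, batch):
    """Loop a scalar -> scalar core function over a 1D batch and stack the results."""
    # Every call returns a 0-d value, so the stacked result is rectangular.
    return np.stack([fn(x) for x in batch])


print(naive_blockwise(core_fn, [1, 2, 3, 4]))  # [0 1 3 6]
```

The same loop applied to Arange alone would fail at the stack step, which is exactly the difference the OpFromGraph boundary makes.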
It would be nice to use this trick to support end-to-end vectorization in these cases. Some of the logic needed to infer whether an Op has a square (non-ragged) output shape is being developed in #1015.
Some of the logic developed in pymc-devs/pymc-extras#300 to understand how dimensions propagate over nodes could be repurposed to determine when a subgraph collapses ragged dimensions.
We can start very simple and only allow ragged dimensions that are reduced immediately, as in the Arange + Sum example.
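A deliberately conservative version of "only allow immediate reductions" could be checked as follows. This is a pure-Python sketch on a toy graph: the `Node` class, the op-name sets, and `ragged_dims_reduced_immediately` are all hypothetical and do not reflect PyTensor's graph API or the shape-inference work in #1015. The rule it encodes: a node whose output shape depends on input values must not be a subgraph output, and every one of its clients must be a reduction assumed to collapse all axes.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can live in sets/dicts
class Node:
    """Toy graph node (hypothetical; not PyTensor's Apply/Variable API)."""
    op: str
    inputs: list = field(default_factory=list)


# Hypothetical op classification, for illustration only.
VALUE_DEPENDENT_SHAPE = {"arange"}  # output length depends on input *values*
FULL_REDUCTIONS = {"sum", "prod", "max", "min"}  # assumed to reduce over every axis


def toposort(outputs):
    """Return all nodes reachable from `outputs`, inputs before their clients."""
    order, seen = [], set()

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    for out in outputs:
        visit(out)
    return order


def ragged_dims_reduced_immediately(outputs):
    """True if every ragged-producing node is consumed only by full reductions."""
    nodes = toposort(outputs)
    clients = {node: [] for node in nodes}
    for node in nodes:
        for inp in node.inputs:
            clients[inp].append(node)

    for node in nodes:
        if node.op not in VALUE_DEPENDENT_SHAPE:
            continue
        if node in outputs:
            return False  # the ragged value would escape the subgraph
        if not all(client.op in FULL_REDUCTIONS for client in clients[node]):
            return False  # some consumer keeps the ragged dimension alive
    return True


i = Node("input")
print(ragged_dims_reduced_immediately([Node("sum", [Node("arange", [i])])]))  # True
print(ragged_dims_reduced_immediately([Node("arange", [i])]))  # False
print(ragged_dims_reduced_immediately([Node("sum", [Node("exp", [Node("arange", [i])])])]))  # False
```

Note that `sum(exp(arange(i)))` is rejected even though it would be safe, because the ragged dimension passes through an elementwise op before being reduced. Relaxing that restriction is where dimension-propagation logic like that in pymc-devs/pymc-extras#300 would come in.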