Handle non-square intermediate Blockwise operations #1017

Open
@ricardoV94

Description

When output shapes depend on input values, Blockwise operations are not necessarily valid at runtime. For example, the following graph cannot be evaluated by PyTensor, because it would require support for ragged arrays in the intermediate Blockwise(Arange):

import pytensor.tensor as pt
from pytensor.graph import vectorize

i = pt.scalar("i", dtype=int)
y = pt.sum(pt.arange(0, i))

new_i = pt.vector("new_i", dtype=int)
new_y = vectorize(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # ValueError
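
To illustrate why the intermediate result is ragged, here is a minimal plain-Python sketch of what a batched arange would have to produce for this input. It does not use PyTensor, and `batched_arange` is a hypothetical helper introduced only for illustration:

```python
def batched_arange(stops):
    # One arange per batch entry; each row ends up with a different length.
    return [list(range(0, stop)) for stop in stops]

rows = batched_arange([1, 2, 3, 4])
row_lengths = [len(row) for row in rows]

# Rows can only be stacked into a regular (non-ragged) array if all lengths match.
is_rectangular = len(set(row_lengths)) <= 1

print(rows)
print(row_lengths, is_rectangular)
```

Since the row lengths differ, there is no regular array shape that could hold the batched output.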

However, if we wrap the Arange + Sum in an OpFromGraph, that subgraph becomes a valid Blockwise, and PyTensor is happy to evaluate it:

import pytensor.tensor as pt
from pytensor.graph import vectorize
from pytensor.compile.builders import OpFromGraph

i = pt.scalar("i", dtype=int)
y_ = pt.sum(pt.arange(0, i))
y = OpFromGraph([i], [y_])(i)

new_i = pt.vector("new_i", dtype=int)
new_y = vectorize(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # [0, 1, 3, 6]
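
The wrapped version works because each batch entry is reduced to a scalar before the results are stacked. A rough plain-Python sketch of that evaluation order follows; `core_fn` and `blockwise_scalar` are hypothetical names for illustration, not PyTensor's actual implementation:

```python
def core_fn(i):
    # Core graph: arange followed immediately by sum, so the output is a scalar.
    return sum(range(0, i))

def blockwise_scalar(fn, batch):
    # Apply the core function once per batch entry.
    # Scalar outputs always stack into a regular 1-D result.
    return [fn(x) for x in batch]

result = blockwise_scalar(core_fn, [1, 2, 3, 4])
print(result)  # [0, 1, 3, 6]
```

The ragged dimension only ever exists inside a single call to the core function, so it never has to be materialized across the batch.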

It would be nice to use this trick to support end-to-end vectorization in these cases. Some of the logic needed to infer whether an Op has a square shape or not is being developed in #1015.

Some of the logic developed in pymc-devs/pymc-extras#300 to understand how dimensions propagate over nodes could be repurposed to figure out in which cases a subgraph collapses ragged dimensions.

We can start very simple and allow only immediate reductions of ragged dimensions.
