
Vectorize follow-up #430

Closed
@ricardoV94

Description

Now that #306 is merged, there are a couple of follow-ups we should do:

  1. Blockwise more Ops (Everything in linalg?)
  2. Dispatch vectorize for more Ops
    1. Alloc (Blockwise improvements #532)
    2. Shape (Vectorize dispatch for shape operations #454)
    3. Subtensor (Blockwise improvements #532)
    4. Arange (Blockwise improvements #532)
    5. ExtractDiag (Blockwise improvements #532)
    6. Assert
  3. Implement JAX and Numba dispatch
    1. JAX: use jax vectorize (Support Blockwise in JAX backend #487)
    2. Numba: use the machinery developed in #691 (Add support for random Generators in Numba backend)
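Both "Blockwise more Ops" and the JAX item revolve around gufunc-style core signatures. As a rough illustration (plain NumPy, not PyTensor API), `numpy.vectorize` with a `signature` loops a core function over the leading batch dimensions, which is the behaviour a Blockwise `solve` with core signature `(m,m),(m)->(m)` would need. `jax.numpy.vectorize` accepts the same `signature` argument, which is presumably what "use jax vectorize" refers to. The `core_solve` helper below is hypothetical and only for illustration.

```python
import numpy as np

# Hypothetical core (non-batched) operation: solve one m x m system A @ x = b.
def core_solve(A, b):
    return np.linalg.solve(A, b)

# Gufunc-style core signature, like the one a Blockwise linalg Op would carry.
batched_solve = np.vectorize(core_solve, signature="(m,m),(m)->(m)")

rng = np.random.default_rng(0)
# Batch of 5 well-conditioned 3x3 systems (identity plus small noise).
A = np.eye(3) + 0.1 * rng.normal(size=(5, 3, 3))
b = rng.normal(size=(5, 3))

x = batched_solve(A, b)
print(x.shape)  # (5, 3)

# The batched result matches an explicit Python loop over the batch dimension.
loop = np.stack([np.linalg.solve(A[k], b[k]) for k in range(5)])
print(np.allclose(x, loop))  # True
```

Batch dimensions broadcast as usual, so a single `A` of shape `(3, 3)` against a `b` of shape `(5, 3)` also yields five solutions.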
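Separately, some graphs with non-square intermediate outputs can already be vectorized if they are wrapped in an `OpFromGraph`, as long as the non-square dimension is reduced inside it:

```python
import pytensor.tensor as pt
from pytensor.graph import vectorize
from pytensor.compile.builders import OpFromGraph

# Core graph: the length of arange(0, i) depends on the *value* of i,
# but summing over it yields a scalar
i = pt.scalar("i", dtype=int)
y_ = pt.sum(pt.arange(0, i))
# Wrap the whole sequence in an OpFromGraph, so vectorize only sees
# a scalar -> scalar Op
y = OpFromGraph([i], [y_])(i)

# Replace the scalar input with a vector and vectorize the graph
new_i = pt.vector("new_i", dtype=int)
new_y = vectorize(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # [0, 1, 3, 6]
```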

We could explore automatically wrapping such sequences (a non-square Blockwise Op followed by a reduction over its non-square dimensions) in a Blockwise OpFromGraph during rewrites, to support those cases.
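Why fusing helps can be seen outside PyTensor. In this NumPy sketch (an analogy, not PyTensor's rewrite machinery), vectorizing `arange` on its own has no fixed core output shape, because each batch element produces a different length. Vectorizing the fused `arange` + `sum` unit gives one scalar per element, which is what the OpFromGraph in the example computes per batch entry. The `arange_sum` helper is hypothetical.

```python
import numpy as np

batch = np.array([1, 2, 3, 4])

# Step by step: arange alone is "non-square", since its output length n
# depends on the input value, so there is no consistent core dimension n.
arange_only = np.vectorize(np.arange, signature="()->(n)")
try:
    arange_only(batch)
except ValueError as err:
    print("non-square output:", err)

# Hypothetical fused unit: arange followed by a reduction over the
# non-square dimension maps each scalar to a scalar ("()->()").
def arange_sum(i):
    return np.arange(0, i).sum()

fused = np.vectorize(arange_sum, signature="()->()")
print(fused(batch))  # [0 1 3 6]
```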
