
ENH: Symbolic Vectorization #109

Closed
@ferrine

Description

```python
import pytensor
from pytensor.graph.fg import FunctionGraph

# Proposed API; follows the JAX vmap API:
# https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html
fg = FunctionGraph([inputs], [outputs])
vfg = pytensor.graph.vectorize(fg, in_axes=[0], out_axes=[0], axis_name={0: "batch"})

# Additionally, some collective Ops seem to be useful.
...
# Placeholders: these only become real Ops once pytensor.graph.vectorize is called.
sum_batch = pytensor.tensor.collective.Sum(tensor, axis="batch")
mean_batch = pytensor.tensor.collective.Mean(tensor, axis="batch")
...
```
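To pin down what `in_axes`/`out_axes` would mean, here is a minimal loop-based reference in NumPy. It is a sketch under stated assumptions, not PyTensor code: `loop_vectorize` and `core` are hypothetical names, and `None` in `in_axes` is assumed to mark an unbatched (shared) input, as in `jax.vmap`. A symbolic vectorizer would aim to produce the same result as a single batched graph, without the Python loop.

```python
import numpy as np


def loop_vectorize(core_fn, in_axes, out_axes):
    """Loop-based reference semantics for vmap-style vectorization (hypothetical helper).

    core_fn:  takes unbatched arrays and returns a tuple of arrays
    in_axes:  batch axis for each input, or None if that input is shared
    out_axes: axis at which the batch dimension is inserted in each output
    """
    def batched(*args):
        # All batched inputs must agree on a single batch size.
        sizes = {a.shape[ax] for a, ax in zip(args, in_axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError("batched inputs must share a single batch size")
        (n,) = sizes
        results = []
        for i in range(n):
            # Slice out batch element i; shared inputs pass through unchanged.
            item = [a if ax is None else np.take(a, i, axis=ax)
                    for a, ax in zip(args, in_axes)]
            results.append(core_fn(*item))
        # Stack output k of every batch element along out_axes[k].
        return tuple(
            np.stack([res[k] for res in results], axis=ax)
            for k, ax in enumerate(out_axes)
        )
    return batched


def core(x, w):
    # Core (unbatched) graph: a vector x of shape (3,) times a shared (3, 2) matrix w.
    return (x @ w,)


x_batch = np.arange(12.0).reshape(4, 3)  # batch of 4 vectors along axis 0
w = np.ones((3, 2))

(y,) = loop_vectorize(core, in_axes=[0, None], out_axes=[0])(x_batch, w)
print(y.shape)  # (4, 2)

# The single batched op a symbolic vectorizer would ideally emit instead of the loop.
assert np.allclose(y, x_batch @ w)

# Assumed resolution of a collective Sum(..., axis="batch") applied to the core output:
# an ordinary reduction over the axis chosen in out_axes.
print(y.sum(axis=0))  # [66. 66.]
```

The last line is only an illustration of how the proposed collective Ops could be lowered once the named "batch" axis becomes a concrete axis; the issue does not specify this mechanism.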

Context for the issue:

Vectorizing graphs symbolically, through graph rewriting, would be a major step toward improving the PyMC/PyTensor user experience.

Example use cases:
