Description
```python
import pytensor
from pytensor.graph.fg import FunctionGraph

fg = FunctionGraph([inputs], [outputs])
# proposed API, following JAX's vmap: https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html
vfg = pytensor.graph.vectorize(fg, in_axis=[0], out_axis=[0], axis_name={0: "batch"})

# collective Ops that reduce over a named batch axis would also be useful
...
# these Ops are no-ops until pytensor.graph.vectorize is called
sum_batch = pytensor.tensor.collective.Sum(tensor, axis="batch")
mean_batch = pytensor.tensor.collective.Mean(tensor, axis="batch")
...
```
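To pin down the intended semantics, here is a minimal NumPy sketch of the *reference* behaviour a vectorizing rewrite would have to reproduce. It assumes `in_axis`/`out_axis` act like JAX's `in_axes`/`out_axes` (map over one axis of the input, stack results along one axis of the output) and that a collective over `"batch"` reduces across the mapped examples. `vectorize_ref` and `example_fn` are hypothetical helpers for illustration, not PyTensor API.

```python
import numpy as np


def vectorize_ref(fn, in_axis=0, out_axis=0):
    """Hypothetical reference semantics: map `fn` over `in_axis` of its
    single input and stack the per-example outputs along `out_axis`."""
    def batched(x):
        x = np.asarray(x)
        # bring the mapped axis to the front, then apply fn to each slice
        per_example = [fn(xi) for xi in np.moveaxis(x, in_axis, 0)]
        return np.stack(per_example, axis=out_axis)
    return batched


def example_fn(v):
    # per-example graph: written for a single vector, with no batch axis
    return v * 2.0 + v.sum()


x = np.arange(6.0).reshape(2, 3)  # batch of 2 examples, 3 features each
looped = vectorize_ref(example_fn)(x)

# what a symbolic rewrite should produce instead: one batched expression
batched = x * 2.0 + x.sum(axis=1, keepdims=True)
assert np.allclose(looped, batched)

# a collective Sum/Mean over "batch" reduces across the mapped examples
sum_batch = looped.sum(axis=0)
mean_batch = looped.mean(axis=0)
```

The loop is only a specification; the point of a graph-level `vectorize` would be to emit the batched expression directly, without a Python loop or `scan`.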
Context for the issue:
Vectorizing operations symbolically through graph rewriting would be a major improvement to the PyMC/PyTensor user experience.
Example use cases:
- Remove `scan` from VI: https://github.com/pymc-devs/pymc/blob/main/pymc/variational/opvi.py#L1355
- Vectorize NUTS over chains (still to be designed and implemented)
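Both use cases share one pattern: a per-draw (or per-chain) computation that is currently applied sequentially could instead run once over a batch axis. The NumPy sketch below illustrates that pattern only; it does not reproduce the linked PyMC code. It assumes a hypothetical per-draw log-density `logp_single` (independent standard normals) and compares a loop-based Monte Carlo average, analogous to a `scan`, with the equivalent batched computation.

```python
import numpy as np


def logp_single(theta):
    # hypothetical per-draw log-density: independent standard normals
    return -0.5 * np.sum(theta**2) - 0.5 * theta.size * np.log(2 * np.pi)


def mc_estimate_loop(draws):
    # sequential accumulation over draws, analogous to a scan
    total = 0.0
    for theta in draws:
        total += logp_single(theta)
    return total / len(draws)


def mc_estimate_batched(draws):
    # same computation with the draw axis treated as a batch dimension
    d = draws.shape[1]
    logps = -0.5 * np.sum(draws**2, axis=1) - 0.5 * d * np.log(2 * np.pi)
    return logps.mean()


rng = np.random.default_rng(0)
draws = rng.normal(size=(1000, 4))  # 1000 draws of a 4-dimensional parameter
assert np.isclose(mc_estimate_loop(draws), mc_estimate_batched(draws))
```

Here the batched version was written by hand; the feature requested above would derive it automatically from the single-draw graph.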