Description
Goal
We should implement a new non-user-facing Op that is introduced during specialization to fuse Composite and CAReduce Ops.
Terminology
- Scalar Op: Op that performs a basic operation on scalar inputs (add, exp, sin, ...)
- Elemwise Op: Op that extends a Scalar Op to tensor inputs (similar to a NumPy ufunc). This is what users usually build graphs in terms of when they write `tensor.add`, `tensor.exp`, ...
- Composite Op: A scalar Op that performs multiple pure scalar operations on a set of scalar inputs. It "fuses" them in a single pass. Composite Ops can also be turned into Elemwise Ops, which leads us to...
- FusionRewrite: Rewrite during the specialization phase that replaces multiple Elemwise operations by a single Elemwise Composite, so as to avoid iterating over each intermediate output.
- CAReduce: An `Op` that performs a Commutative Associative Reduce operation on tensor inputs. It has a core binary Scalar `Op` such as `Add` or `And`, which is applied sequentially to the inputs along certain axes until those are reduced (a small illustration follows this list).
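For illustration, here is a minimal snippet showing that a reduction such as `sum` is built as a `CAReduce` node carrying its core scalar op (a sketch only; the exact printed names may differ between versions):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
total = x.sum()  # a CAReduce (Sum) node reducing over all axes

# The reduction Op carries its core binary scalar op (here: add)
print(type(total.owner.op).__name__)
print(total.owner.op.scalar_op)
```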
Description
Users don't usually work with Composites directly. Instead we have a FusionRewrite introduced during the specialization phase that replaces sequences of simple Elemwise Ops by a single large Elemwise Composite:
You can see it here:
```python
import pytensor
import pytensor.tensor as pt

xs = pt.vector("xs")
out = pt.exp(pt.sin(pt.cos(pt.log(xs))))
print("Before:")
pytensor.dprint(out)

f_out = pytensor.function([xs], out)
print("\nAfter:")
pytensor.dprint(f_out)
```
```
Before:
Elemwise{exp,no_inplace} [id A]
 |Elemwise{sin,no_inplace} [id B]
   |Elemwise{cos,no_inplace} [id C]
     |Elemwise{log,no_inplace} [id D]
       |xs [id E]

After:
Elemwise{Composite} [id A] 0
 |xs [id B]

Inner graphs:

Elemwise{Composite} [id A]
 >exp [id C]
 > |sin [id D]
 >   |cos [id E]
 >     |log [id F]
 >       |<float64> [id G]
```
We would like to follow up the FusionRewrite with a new rewrite that fuses Composites with CAReduce operations. So if we have a graph like this:
```python
xs = pt.vector("xs")
elemwise_out = pt.exp(pt.sin(pt.cos(pt.log(xs))))
reduction_out = pt.sum(elemwise_out)
```
We want to perform a single pass through `xs`. The evaluation pseudo-code in Python would look something like:
```python
reduction_out = 0  # Initial value of the Add CAReduce
for x in xs.ravel():
    reduction_out += exp(sin(cos(log(x))))
```
This requires either extending CAReduce or creating a new Op to represent this new type of composite operation. Ideally it would support Composites with multiple reduced and non-reduced outputs, and multiple types of reduction (say `And` in one output and `Add` in another). I am not sure about axis... so we might only apply it to the `None` case for now, where all axes are reduced.
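As a rough illustration of that multiple-output case, the single-pass evaluation pseudo-code could look like the following (the function name and the particular outputs/reductions are made up for this example, not part of any existing API):

```python
import numpy as np

def fused_eval(xs):
    xs = xs.ravel()                   # axis=None: reduce over everything
    add_out = 0.0                     # identity of the Add reduction
    and_out = True                    # identity of the And reduction
    elem_out = np.empty_like(xs)      # a non-reduced Elemwise output
    for i, x in enumerate(xs):
        y = np.exp(np.sin(np.cos(np.log(x))))  # fused Composite body
        elem_out[i] = y               # output kept at full size
        add_out += y                  # Add-reduced output
        and_out &= y > 0.5            # And-reduced output
    return elem_out, add_out, and_out
```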
As we want to discontinue the C backend, I think we should focus on implementing this for the Numba backend alone (and Python, for debugging purposes). The JAX backend probably does optimizations of this kind directly and we are not working at such a low level there anyway, so we wouldn't do anything about JAX either.
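For the simple single-output `Add` case, the kind of loop the Numba backend might ultimately run could look roughly like this (a hand-written sketch only; the real implementation would generate such code from the Op's inner graph):

```python
import numba
import numpy as np

@numba.njit
def fused_sum(xs):  # assumes a 1-D float64 input and an axis=None reduction
    acc = 0.0  # identity of the Add CAReduce
    for i in range(xs.shape[0]):
        # fused Composite body evaluated and accumulated in a single pass
        acc += np.exp(np.sin(np.cos(np.log(xs[i]))))
    return acc
```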
Steps
- Implement a new Op that performs Elemwise + partial CAReduce operations on certain outputs. This Op must contain all the meta-information needed to create the right evaluation function (a hypothetical skeleton follows this list). This Op need not have gradients nor a C or JAX implementation.
- Implement Python perform method (for debugging purposes, arguably lower priority, but it has proved useful to debug and understand other complex Ops like Scan)
- Implement Numba perform method (some foundational work was done on A more low level implementation of vectorize in numba #92, see below)
- Implement Numba-only rewrite that replaces Composite + CAReduce by new Op
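Purely as a hypothetical sketch of the first step, the new Op might carry meta-information along these lines (none of these class or attribute names exist in PyTensor; they only illustrate what would need to be stored):

```python
from pytensor.graph.op import Op

class FusedElemwiseCAReduce(Op):
    """Hypothetical Op fusing an Elemwise Composite with partial reductions."""

    def __init__(self, scalar_composite, reduced_outputs, reduce_scalar_ops, axis=None):
        self.scalar_composite = scalar_composite    # the fused scalar inner graph
        self.reduced_outputs = reduced_outputs      # indices of outputs that get reduced
        self.reduce_scalar_ops = reduce_scalar_ops  # e.g. scalar Add/And per reduced output
        self.axis = axis                            # None for now: reduce over all axes

    def make_node(self, *inputs):
        raise NotImplementedError

    def perform(self, node, inputs, output_storage):
        # Python implementation would go here (debugging reference)
        raise NotImplementedError
```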
Relevant links
- A more low level implementation of vectorize in numba #92 implemented a low-level Elemwise in Numba, already with the capacity to accumulate partial outputs, but only for Addition. The Numba-related work would be done here.
- Downstream 1285: Add a fusion rewrite for CAReduces with Elemwise inputs #40 made some small-scope progress here, by using single-output Composites as the `ScalarOp` of `CAReduce`. This however won't work for multiple mixed outputs.
- Fuse `Elemwise` graphs that have multiple outputs and clients #121 extended the FusionRewrite to handle multiple-output Composites.
- Optimize `Sum`s of `MakeVector`s and `Join`s #59 is ever more relevant here, as it would increase the number of Elemwise+Reduction graphs that could be optimized, by removing unnecessary shape-related operations and increasing the chance that they happen immediately below a Composite operation.
- FusionOptimizer truncation logic should be backend specific #140