
Should Alloc be pushed downstream of expand_dims #884

Open
@ricardoV94

We have other rewrites that push Alloc below Elemwise so that we don't compute on repeated inputs, but they don't apply when there's an expand_dims in the way. As of now the following graph is not rewritten, and Exp is still computed on the full Alloc output:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
y = pt.alloc(x, 1000, 3)[None]
out = pt.exp(y)
pytensor.function([x], out).dprint(print_type=True)
# Exp [id A] <Tensor3(float64, shape=(1, 1000, 3))> 1
#  └─ Alloc [id B] <Tensor3(float64, shape=(1, 1000, 3))> 0
#     ├─ x [id C] <Vector(float64, shape=(3,))>
#     ├─ 1 [id D] <Scalar(int8, shape=())>
#     ├─ 1000 [id E] <Scalar(int16, shape=())>
#     └─ 3 [id F] <Scalar(int8, shape=())>
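
To make the cost concrete, here is a small pure-Python sketch (not PyTensor code) that counts scalar `exp` evaluations for the two orderings. The helpers `CountingExp` and `alloc_rows` are hypothetical stand-ins for `Exp` and `Alloc(...)[None]`, assumed only for illustration:

```python
import math


class CountingExp:
    """Toy elementwise exp that records how many scalars it evaluates."""

    def __init__(self):
        self.calls = 0

    def __call__(self, value):
        self.calls += 1
        return math.exp(value)


def alloc_rows(row, n_rows):
    # Toy stand-in for alloc(row, n_rows, len(row))[None]:
    # repeat `row` n_rows times and wrap in a leading axis of length 1.
    return [[list(row) for _ in range(n_rows)]]


x = [0.0, 1.0, 2.0]

# Ordering in the graph above: Exp applied to the Alloc output.
exp_after = CountingExp()
out_after = [[[exp_after(v) for v in row] for row in mat] for mat in alloc_rows(x, 1000)]

# Desired ordering: Exp applied to x, Alloc pushed downstream.
exp_before = CountingExp()
out_before = alloc_rows([exp_before(v) for v in x], 1000)

assert out_after == out_before
print(exp_after.calls, exp_before.calls)  # 3000 3
```

Both orderings give the same values, but the first evaluates `exp` once per repeated element instead of once per input element.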

There is actually an "uncanonicalize" rewrite that lifts expand_dims above an Alloc when it only adds dimensions on the left, which would have helped here:

@register_uncanonicalize
@node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node):
    """
    If an alloc is inside a dimshuffle which only adds dimension to the left,
    scrap the dimshuffle and adds 1 into the alloc
    dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
    """
    if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
        input_ = node.inputs[0]
        if isinstance(input_.owner.op, Alloc):
            # check if it only adds dimension to the left
            new_order = node.op.new_order
            expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple(
                range(input_.ndim)
            )
            if new_order != expected_new_order:
                return False
            # count numbers of 'x'
            nb_new_dims = len(new_order) - input_.ndim
            new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
            return [alloc(input_.owner.inputs[0], *new_shape_input)]
    return False
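
The pattern check and shape arithmetic in that rewrite can be exercised on their own. The following is a minimal sketch on plain tuples, with no PyTensor graph involved; the function name `lift_left_expand_dims` is hypothetical and only mirrors the `new_order` comparison above:

```python
def lift_left_expand_dims(new_order, alloc_shape):
    """Return the Alloc shape that absorbs a left-only expand_dims, else None.

    `new_order` mimics DimShuffle.new_order ("x" marks a new broadcastable
    axis); `alloc_shape` mimics the shape arguments of the inner Alloc.
    """
    ndim = len(alloc_shape)
    n_new = len(new_order) - ndim
    # The DimShuffle must be exactly ("x", ..., "x", 0, 1, ..., ndim - 1).
    expected = ("x",) * n_new + tuple(range(ndim))
    if tuple(new_order) != expected:
        return None
    # Prepend one length-1 entry per new axis.
    return (1,) * n_new + tuple(alloc_shape)


# The docstring example: dimshuffle{x, 0, 1}(alloc(v, 3, 2)) => alloc(v, 1, 3, 2)
print(lift_left_expand_dims(("x", 0, 1), (3, 2)))     # (1, 3, 2)
# The case from this issue: alloc(x, 1000, 3)[None]
print(lift_left_expand_dims(("x", 0, 1), (1000, 3)))  # (1, 1000, 3)
# A new axis in the middle, or a transpose, does not match.
print(lift_left_expand_dims((0, "x", 1), (3, 2)))     # None
print(lift_left_expand_dims((1, 0), (3, 2)))          # None
```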

However, this is at odds with the canonical rewrite local_alloc_sink_dimshuffle, which does the opposite:

@register_specialize
@register_stabilize
@register_canonicalize
@node_rewriter([Alloc])
def local_alloc_sink_dimshuffle(fgraph, node):
    r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""

It's not obvious to me why the latter should be given preference. In general, it seems we can always lift expand_dims towards the inputs of the function (it does not affect the number of operations) and sink Alloc towards the outputs. But here we do not allow the "swap" when an expand_dims meets an Alloc.
