Implement more meaningful Reshape operation #882

Closed
@ricardoV94

Description

Analyzing graphs with reshape operations is rather complex because Reshape represents what we want, but not "what it means".

Except for esoteric cases where Reshape shapes may come from a complex computation or from the shapes of other variables, a reshape usually multiplies some dimensions (merging) and divides others (splitting). We could represent these cases with some sort of symbolic mapping:

```python
from pytensor.tensor import tensor

x = tensor(shape=(4, 3, 2))
x.reshape((4, 6))  # JoinDims(0, (1, 2))
```

It almost begs for an extension of DimShuffle, which was brought up before: Theano/Theano#4640
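To make the `JoinDims(0, (1, 2))` notation concrete, here is a minimal sketch, assuming each argument describes one output dim: an int keeps that input dim as is, and a tuple of consecutive input dims merges them into one. `join_dims_shape` is a hypothetical helper, not part of PyTensor; it only computes the output shape such an Op would produce, with plain tuples standing in for symbolic shapes.

```python
from math import prod


def join_dims_shape(in_shape, *groups):
    """Output shape of a hypothetical JoinDims(*groups) applied to in_shape.

    Each group describes one output dim: an int keeps that input dim as is,
    a tuple of consecutive input dims is merged into a single dim whose
    length is the product of theirs. Every input dim must appear exactly
    once, in order, so the mapping is unambiguous.
    """
    covered = []
    out_shape = []
    for group in groups:
        axes = (group,) if isinstance(group, int) else tuple(group)
        covered.extend(axes)
        out_shape.append(prod(in_shape[a] for a in axes))
    if covered != list(range(len(in_shape))):
        raise ValueError(f"groups {groups} must cover dims 0..{len(in_shape) - 1} in order")
    return tuple(out_shape)


print(join_dims_shape((4, 3, 2), 0, (1, 2)))  # (4, 6), same as x.reshape((4, 6))
print(join_dims_shape((4, 3, 2), (0, 1), 2))  # (12, 2)
```

Because the groups name the dims being merged, the output shape follows from the input shape alone; no shape graph has to be inspected to know what the operation means.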

Splitting dims is trickier because there are many choices: we can split in different orders and into different sizes:

```python
from pytensor.tensor import tensor

x = tensor(shape=(12,))
x.reshape((2, 2, 3))
x.reshape((2, 3, 2))
x.reshape((4, 3))
...
```

Still, an Op that achieves the same result as a splitting reshape, but knows which dims go where (and in what sizes), would be more readable.
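A comparable sketch for splitting, assuming a hypothetical `SplitDims(axis, sizes)` that names the input dim being split and the sizes it is split into, in order. `split_dim_shape` is illustrative only; as with reshape, one size may be -1 and is inferred from the others.

```python
from math import prod


def split_dim_shape(in_shape, axis, sizes):
    """Output shape of a hypothetical SplitDims(axis, sizes) applied to in_shape.

    `sizes` replaces dim `axis` with len(sizes) new dims. At most one entry
    may be -1, in which case it is inferred so that the sizes multiply to
    the original length, mirroring reshape's convention.
    """
    sizes = list(sizes)
    if sizes.count(-1) > 1:
        raise ValueError("at most one size can be -1")
    length = in_shape[axis]
    if -1 in sizes:
        known = prod(s for s in sizes if s != -1)
        if known == 0:
            raise ValueError("cannot infer -1 when another size is 0")
        sizes[sizes.index(-1)] = length // known
    if prod(sizes) != length:
        raise ValueError(f"cannot split a dim of length {length} into {sizes}")
    return tuple(in_shape[:axis]) + tuple(sizes) + tuple(in_shape[axis + 1:])


# The reshapes of a length-12 vector from above, each stated explicitly
print(split_dim_shape((12,), 0, (2, 2, 3)))  # (2, 2, 3)
print(split_dim_shape((12,), 0, (2, 3, 2)))  # (2, 3, 2)
print(split_dim_shape((12,), 0, (4, -1)))    # (4, 3)
```

The ambiguity of splitting does not go away, but it moves into the Op's parameters, where a rewrite can read it directly instead of reverse-engineering it from a target shape.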


An example where Reshape is currently hard to work with is vectorization. If we have a common graph like `reshape(x, (x.shape[0] * x.shape[1], -1))`, we cannot eagerly return the desired output `reshape(new_x, (new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1))`, because there is a chain of complex operations (Shape -> Subtensor -> Mul -> MakeVector) that we must vectorize before we get to the Reshape node. So we need to wrap it in a costly Blockwise and try our best to remove it during rewrites. This came up in #702 when vectorizing tensordot to get a batched_tensordot.

Such a problem wouldn't exist with a symbolic reshape that is told what dims are being joined/split.
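Under the same assumed group notation, vectorizing a join only requires shifting its axis indices past the new batch dims. For a 3D `x`, `reshape(x, (x.shape[0] * x.shape[1], -1))` corresponds to `JoinDims((0, 1), 2)`; with one batch dim the groups become `0, (1, 2), 3`, which merges dims 1 and 2 of the batched input and leaves the batch dim alone. The helper `vectorize_join_dims` is hypothetical.

```python
def vectorize_join_dims(groups, n_batch):
    """Adapt a hypothetical JoinDims spec to an input with n_batch new leading dims.

    The batch dims are kept as they are (one int group each) and every
    original axis index is shifted by n_batch. Nothing about the symbolic
    shape graph (Shape -> Subtensor -> Mul -> MakeVector) needs inspecting.
    """
    new_groups = list(range(n_batch))
    for group in groups:
        if isinstance(group, int):
            new_groups.append(group + n_batch)
        else:
            new_groups.append(tuple(axis + n_batch for axis in group))
    return new_groups


# reshape(x, (x.shape[0] * x.shape[1], -1)) on a 3D x, i.e. JoinDims((0, 1), 2)
print(vectorize_join_dims([(0, 1), 2], n_batch=1))  # [0, (1, 2), 3]
```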

It would also make rewrites that remove or lift reshapes much simpler than they currently are. Compare the existing `local_useless_reshape` rewrite:

```python
# Imports as in PyTensor at the time of this issue (module paths may have changed since)
from pytensor.tensor.basic import MakeVector, extract_constant
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.subtensor import Subtensor


def local_useless_reshape(fgraph, node):
    """Remove two kinds of useless `Reshape`.

    - Remove `Reshape` when both the input and output have a single dimension.
    - Remove `Reshape` when reshaping to the shape of the input.
    """
    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    if inp.type.ndim != output.type.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # TODO FIXME XXX: This could hide errors if the user provides inconsistent
    # shapes.
    if (
        inp.type.ndim == 1
        and output.type.ndim == 1
        and all(
            s1 == s2
            for s1, s2 in zip(inp.type.shape, output.type.shape)
            if s1 == 1 or s2 == 1
        )
    ):
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs
        shape_feature = getattr(fgraph, "shape_feature", None)
        nb_m1 = 0
        shape_match = [False] * inp.type.ndim
        for dim in range(inp.type.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue
            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue
            # Match 1 if input.type.shape[dim] == 1
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
                shape_match[dim] = True
                continue
            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue
            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        if all(shape_match) and nb_m1 <= 1:
            return [inp]

        # TODO later: if all the shapes except one match, we may want to
        # consider it useless as well, like we do in the 1-dim case.
        return False
```
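For contrast, a minimal sketch of the same check under the assumed group notation: a join is useless exactly when no group merges more than one dim, and a split is useless when a dim is "split" into a single piece. Both helpers are hypothetical and assume the Op has already validated its groups or sizes against the input shape.

```python
def is_useless_join_dims(groups, ndim):
    """True if a hypothetical JoinDims(*groups) on an ndim input is the identity."""
    axes = []
    for group in groups:
        group = (group,) if isinstance(group, int) else tuple(group)
        if len(group) != 1:
            return False  # some dims are actually merged
        axes.append(group[0])
    return axes == list(range(ndim))


def is_useless_split_dims(sizes):
    """True if a hypothetical SplitDims(axis, sizes) leaves the dim unchanged.

    Assumes `sizes` was already validated to multiply to the dim's length.
    """
    return len(sizes) == 1


print(is_useless_join_dims([0, 1, 2], ndim=3))    # True
print(is_useless_join_dims([0, (1, 2)], ndim=3))  # False
print(is_useless_split_dims((4, 3)))              # False
```

No pattern matching on Shape, Shape_i, Subtensor or MakeVector graphs is needed, because the Op's parameters already say which dims are being merged or split.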


This is somewhat related to why we have both Second and Alloc. Second is easier to reason about because it tells us immediately that we are broadcasting to the shape of a variable, whereas Alloc specifies the desired output without its meaning (especially after some rewrites, where the shape may become dissociated from the original variable). As explained in the existing docstring:

```python
"""
Notes
-----
There are two ways of broadcasting arrays:
second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape))

The second can be more efficient because x doesn't usually need to be computed when we only want its shape.
It may also allow other rewrites that don't try to modify x when it has multiple clients (for fear of duplicating computation).

However, the first one is easier to reason about.
Knowing we have such a graph allows to do certain rewrites such as "sinking" broadcasting operations below Elemwise.
The same rewrites with alloc would be more complicated as we would need to symbolically combine the shapes of each one.

As an example contrast rewriting the following two equivalent graphs

alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y
second(y, x) + second(x, y) -> x + y

Theano developers (mostly) preferred to use the first form during canonicalization and introduce the second form later,
via rewrites like `local_fill_to_alloc`, and using the `alloc_like` helper inside rewrites.
Many stabilize and stabilization rewrites refuse to be applied when a variable has multiple clients, so this is important.
"""
```
