Description
Analyzing graphs with reshape operations is rather complex, because `Reshape` represents what we want, but not "what it means". Except for esoteric cases where the reshape shape comes from a complex computation or from the shapes of other variables, a reshape usually multiplies some dimensions (merging) and divides others (splitting). We could represent these cases with some sort of symbolic mapping:
```python
x = tensor(shape=(4, 3, 2))
x.reshape(4, 6)  # JoinDims(0, (1, 2))
```
It almost begs for an extension of `DimShuffle`, which was brought up before: Theano/Theano#4640
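A minimal sketch of what the merging case could look like as a user-facing helper (the `join_dims` name and signature are hypothetical, and it is lowered to the existing `reshape` here only to pin down the semantics; a dedicated Op would instead store the axes as metadata):

```python
import pytensor.tensor as pt

def join_dims(x, axes):
    # Hypothetical helper: merge a contiguous run of `axes` into a single
    # dimension. A dedicated Op would record `axes`, so rewrites could read
    # the intent directly instead of reverse-engineering the target shape.
    axes = sorted(axes)
    assert axes == list(range(axes[0], axes[-1] + 1)), "axes must be contiguous"
    shape = [x.shape[i] for i in range(x.type.ndim)]
    merged = shape[axes[0]]
    for i in axes[1:]:
        merged = merged * shape[i]
    return x.reshape(shape[: axes[0]] + [merged] + shape[axes[-1] + 1 :])

x = pt.tensor("x", shape=(4, 3, 2))
y = join_dims(x, (1, 2))  # equivalent to x.reshape((4, 6))
```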
Splitting dims is trickier, because there are many choices: we can split in different orders and sizes.
```python
x = tensor(shape=(12,))
x.reshape(2, 2, 3)
x.reshape(2, 3, 2)
x.reshape(4, 3)
...
```
Still, an Op that achieves the same as splitting via reshape, but knows which dims are going where (and in what quantities), would be more readable.
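A matching sketch for the splitting case (again, `split_dim` is a hypothetical name, lowered to `reshape` purely for illustration):

```python
import pytensor.tensor as pt

def split_dim(x, axis, sizes):
    # Hypothetical helper: split `axis` into new dimensions of the given
    # `sizes` (which must multiply to the length of `axis`). A dedicated Op
    # would keep `axis` and `sizes` as metadata, so it is never ambiguous
    # which output dims came from which input dim.
    shape = [x.shape[i] for i in range(x.type.ndim)]
    return x.reshape(shape[:axis] + list(sizes) + shape[axis + 1 :])

x = pt.tensor("x", shape=(12,))
y = split_dim(x, 0, (2, 2, 3))  # same values as x.reshape((2, 2, 3))
```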
An example where `Reshape` is currently hard to work with is during vectorization. If we have a common graph like `reshape(x, x.shape[0] * x.shape[1], -1)`, we cannot eagerly return the desired output `reshape(new_x, new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1)`, because there is a chain of complex operations we must vectorize before we get to the `Reshape` node (`Shape` -> `Subtensor` -> `Mul` -> `MakeVector`). So we need to put it in a costly `Blockwise` and try our best to remove it during rewrites. This came up in #702 when vectorizing `tensordot` to get a `batched_tensordot`.
Such a problem wouldn't exist with a symbolic reshape that is told which dims are being joined/split.
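With a symbolic op such as the hypothetical `JoinDims` above, the vectorize rule becomes purely mechanical: one new leading batch dimension shifts every recorded axis by one, and no shape graph has to be vectorized first:

```python
def vectorize_join_dims_axes(axes):
    # Hypothetical vectorize rule for a JoinDims(axes) op: under one new
    # leading batch dimension, every recorded axis shifts by 1 and nothing
    # else changes, so the batched node can be built eagerly, no Blockwise.
    return tuple(a + 1 for a in axes)

# join_dims(x, (0, 1)) on x with ndim=3 vectorizes to
# join_dims(new_x, (1, 2)) on new_x with ndim=4:
assert vectorize_join_dims_axes((0, 1)) == (1, 2)
```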
It would also make rewrites that remove/lift reshapes much simpler than they currently are:
`pytensor/pytensor/tensor/rewriting/shape.py`, lines 798 to 895 at `bf73f8a`
This is somewhat related to why we have both `Second` and `Alloc`. The first is easier to reason about, because it tells us immediately that we are broadcasting with the shape of a variable, whereas `Alloc` specifies the desired output without its meaning (especially after some rewrites, where the shape may become dissociated from the original variable):
`pytensor/pytensor/tensor/rewriting/basic.py`, lines 3 to 23 at `d62f4b1`
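For comparison, the `Second`/`Alloc` distinction in code (both ops exist in PyTensor; the exact `pt.second`/`pt.alloc` spellings used here are my best-guess entry points):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.scalar("y")

# Second keeps the meaning: "y, broadcast to the shape of x". The link
# to x survives in the graph itself.
a = pt.second(x, y)

# Alloc only records a target shape. After rewrites, these shape
# variables can become dissociated from x, losing the original intent.
b = pt.alloc(y, x.shape[0], x.shape[1])
```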