Implement more meaningful Reshape operation
#883
ricardoV94 started this conversation in Ideas
Recent related discussions: #1201 and #1192 (comment)
Description
Analyzing graphs with reshape operations is rather complex because Reshape represents not "the meaning", but rather "the final look" of the operation.
Except for esoteric cases where Reshape shapes may come from a complex computation or from the shapes of other variables, a reshape is usually a case of multiplying some dimensions (merging) and dividing others (splitting), plus squeezing/expand_dims. The last two are well encoded by DimShuffle, but there's nothing nice for the first two.
What if we had:
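A minimal NumPy sketch of the semantics such Ops could have. The names join_dims and split_dims come from this discussion, but the signatures used here (axes for join_dims, axis plus sizes for split_dims) are assumptions, and concrete NumPy arrays stand in for symbolic PyTensor variables.

```python
import numpy as np


def join_dims(x, axes):
    """Hypothetical: merge the consecutive axes in `axes` into a single axis.

    The new axis has length equal to the product of the merged lengths,
    which is what Reshape computes when it multiplies dimensions.
    """
    axes = tuple(axes)
    if list(axes) != list(range(axes[0], axes[0] + len(axes))):
        raise ValueError("axes must be consecutive and increasing")
    start, stop = axes[0], axes[-1] + 1
    merged = int(np.prod(x.shape[start:stop]))
    return x.reshape(x.shape[:start] + (merged,) + x.shape[stop:])


def split_dims(x, axis, sizes):
    """Hypothetical: split `axis` into new axes with the given `sizes`.

    `sizes` must multiply to the original length. There is no -1
    placeholder, so the intent of the split is spelled out explicitly.
    """
    sizes = tuple(int(s) for s in sizes)
    if int(np.prod(sizes)) != x.shape[axis]:
        raise ValueError("sizes must multiply to the length of the split axis")
    return x.reshape(x.shape[:axis] + sizes + x.shape[axis + 1:])


x = np.arange(24).reshape(2, 3, 4)
joined = join_dims(x, axes=(0, 1))                   # shape (6, 4)
restored = split_dims(joined, axis=0, sizes=(2, 3))  # shape (2, 3, 4)
```

Both helpers are ordinary reshapes underneath; the difference is that the arguments say which axes are merged or split instead of only the final shape.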
It almost begs for an extension of DimShuffle, which was brought up before: Theano/Theano#4640.
Splitting dims is trickier, because there are many choices: we can split in different orders and sizes.
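One hypothetical form such a DimShuffle extension could take, sketched in NumPy: a pattern where a tuple of input axes means "merge these into one axis". This is not the design from Theano/Theano#4640; extended_dimshuffle and its pattern format are invented here for illustration. Only merges fit naturally into a pattern, because a group of axes fully determines the merged length, whereas a split would also need explicit sizes.

```python
import numpy as np


def extended_dimshuffle(x, pattern):
    """Hypothetical DimShuffle-like pattern that can also merge axes.

    Each entry of `pattern` is an input axis (int), "x" for a new length-1
    axis, or a tuple of input axes merged (in that order) into one axis.
    Every input axis must appear exactly once; dropping axes is not handled.
    """
    order, out_shape = [], []
    for entry in pattern:
        if isinstance(entry, tuple):
            order.extend(entry)
            out_shape.append(int(np.prod([x.shape[a] for a in entry])))
        elif entry == "x":
            out_shape.append(1)
        else:
            order.append(entry)
            out_shape.append(x.shape[entry])
    if sorted(order) != list(range(x.ndim)):
        raise ValueError("every input axis must appear exactly once")
    # Transpose so merged axes are adjacent and in the requested order,
    # then a plain reshape performs the merges and inserts the "x" axes.
    return x.transpose(order).reshape(out_shape)


x = np.arange(24).reshape(2, 3, 4)
y = extended_dimshuffle(x, ((0, 1), "x", 2))  # merge axes 0 and 1, add a broadcastable axis
z = extended_dimshuffle(x, (2, (1, 0)))       # transpose and merge in one step
```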
split_dims is still unfortunate because we don't have symbolic dims. We can say split_dims(..., sizes=(x.shape[0], x.shape[1])) though, which is still a bit more readable than Reshape (especially with the sneaky -1).
How would it be used
Users will probably not know about this specialized Op, but we can introduce it in our internal uses where we know this is the goal. That covers most of the cases I've seen: tensordot, tile, repeat..., matmul rewrites.
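tensordot is a good example of a reshape whose meaning is known upfront. Below is the general reshape-based strategy, with each reshape labelled as a join or a split. It is the textbook technique written in NumPy, not a copy of PyTensor's tensordot, and tensordot_via_reshape is a hypothetical name.

```python
import numpy as np


def tensordot_via_reshape(a, b, n_contract):
    """Contract the last `n_contract` axes of `a` with the first of `b`.

    Every reshape here has a known meaning: join the free axes and the
    contracted axes of each operand, do a 2D matmul, then split the
    output back into the free axes.
    """
    a_free = a.shape[: a.ndim - n_contract]
    b_free = b.shape[n_contract:]
    contracted = a.shape[a.ndim - n_contract :]
    if contracted != b.shape[:n_contract]:
        raise ValueError("contracted axes must have matching lengths")
    k = int(np.prod(contracted))
    a2 = a.reshape(int(np.prod(a_free)), k)  # join free axes | join contracted axes
    b2 = b.reshape(k, int(np.prod(b_free)))  # join contracted axes | join free axes
    out = a2 @ b2
    return out.reshape(a_free + b_free)      # split back into the free axes


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4, 5))
b = rng.standard_normal((4, 5, 6))
result = tensordot_via_reshape(a, b, n_contract=2)  # shape (2, 3, 6)
```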
We can also try to pay the one-time cost of canonicalizing arbitrary Reshapes into join_dims / split_dims. In the end, we can specialize back to Reshape.
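Assuming fully static shapes, one way to canonicalize is to walk the input and output shapes together and group axes whose running products match. The sketch below (decompose_reshape and classify are hypothetical helpers) does exactly that; handling symbolic shapes or -1 is out of scope, and size-1 axes are folded into a neighbouring group rather than peeled off as squeeze/expand_dims.

```python
import numpy as np


def decompose_reshape(in_shape, out_shape):
    """Greedily pair input and output axes into groups with equal products.

    Returns a list of (in_axes, out_axes) tuples. Assumes tuples of static
    lengths >= 1 (no -1, no zero-size axes).
    """
    if min(in_shape + out_shape, default=1) < 1:
        raise ValueError("only static shapes with lengths >= 1 are handled")
    if int(np.prod(in_shape)) != int(np.prod(out_shape)):
        raise ValueError("reshape must preserve the number of elements")
    groups, i, j = [], 0, 0
    while i < len(in_shape) and j < len(out_shape):
        in_axes, in_size = [i], in_shape[i]
        out_axes, out_size = [j], out_shape[j]
        i, j = i + 1, j + 1
        # Grow whichever side has the smaller running product until they match.
        while in_size != out_size:
            if in_size < out_size:
                in_axes.append(i)
                in_size *= in_shape[i]
                i += 1
            else:
                out_axes.append(j)
                out_size *= out_shape[j]
                j += 1
        groups.append((tuple(in_axes), tuple(out_axes)))
    # Whatever is left on either side can only be length-1 axes.
    rest_in = tuple(range(i, len(in_shape)))
    rest_out = tuple(range(j, len(out_shape)))
    if rest_in or rest_out:
        if groups:
            last_in, last_out = groups[-1]
            groups[-1] = (last_in + rest_in, last_out + rest_out)
        else:
            groups.append((rest_in, rest_out))
    return groups


def classify(group):
    """Name the structured operation a single group corresponds to."""
    in_axes, out_axes = group
    if len(in_axes) == 1 and len(out_axes) == 1:
        return "keep"
    if len(out_axes) == 1:
        return "join_dims"
    if len(in_axes) == 1:
        return "split_dims"
    return "join_dims + split_dims"


groups = decompose_reshape((2, 3, 4), (6, 4))
kinds = [classify(g) for g in groups]
```

A group with several axes on both sides, such as (4, 6) to (6, 4), has no single structured meaning and becomes a join followed by a split.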
Existing pain points
An example where Reshape is currently hard to work with is vectorization. If we have a common graph like reshape(x, x.shape[0] * x.shape[1], -1), we cannot eagerly return the desired output reshape(new_x, new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1), because there is a chain of complex operations we must vectorize before we get to the Reshape node (Shape -> Subtensor -> Mul -> MakeVector). So we need to put it in a costly Blockwise and try our best to remove it during rewrites. This came up in #722 when vectorizing tensordot to get a batched_tensordot.
Such a problem wouldn't exist with a join_dims, although it would still exist to some extent with a split_dims.
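A sketch of why join_dims vectorizes trivially, assuming batch dimensions are always prepended: the batched node only needs its axes shifted right, with no shape graph to vectorize. For a 3D x, join_dims(x, axes=(0, 1)) computes the same thing as reshape(x, x.shape[0] * x.shape[1], -1). Both join_dims and vectorize_join_dims are hypothetical helpers, with NumPy standing in for PyTensor.

```python
import numpy as np


def join_dims(x, axes):
    # Hypothetical Op semantics: merge the consecutive `axes` into one axis.
    start, stop = axes[0], axes[-1] + 1
    merged = int(np.prod(x.shape[start:stop]))
    return x.reshape(x.shape[:start] + (merged,) + x.shape[stop:])


def vectorize_join_dims(axes, n_batch_dims):
    # Hypothetical vectorization rule: batch dims are prepended, so the
    # merged axes simply shift right by the number of batch dims.
    return tuple(a + n_batch_dims for a in axes)


x = np.arange(24).reshape(2, 3, 4)            # core input
new_x = np.stack([x, x + 100, x + 200])       # batched input, shape (3, 2, 3, 4)
batched_axes = vectorize_join_dims((0, 1), n_batch_dims=1)
out = join_dims(new_x, batched_axes)          # shape (3, 6, 4)
```

Each batch entry of out equals the core join_dims applied to the corresponding entry of new_x, which is exactly what a Blockwise would have computed.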
Another is for repeated-element-irrelevant reductions, where we should be able to just ignore the reshape:
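One reading of this point, sketched in NumPy: a reduction whose result does not depend on how elements are laid out can ignore a reshape feeding it, and a reduction over a joined axis equals the same reduction over the original axes. Choosing sum and max, and integer data so the comparisons are exact, is an assumption made for illustration.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# A full reduction does not care how elements are arranged,
# so any reshape feeding it could be dropped.
full_sum_matches = x.reshape(6, 4).sum() == x.sum()

# Reducing over a joined axis equals reducing over the axes that were joined,
# so join_dims followed by sum/max can be rewritten to skip the reshape.
joined = x.reshape(2, 12)  # join axes 1 and 2
same_sum = np.array_equal(joined.sum(axis=1), x.sum(axis=(1, 2)))
same_max = np.array_equal(joined.max(axis=1), x.max(axis=(1, 2)))
```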
It also makes rewrites to remove/lift reshapes much simpler than they currently are:
pytensor/pytensor/tensor/rewriting/shape.py
Lines 798 to 895 in bf73f8a
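The existing rewrites referenced above are not reproduced here. As a contrast, this toy sketch shows a rewrite that becomes purely structural with split_dims / join_dims: re-merging exactly the axes a split just produced cancels both. With plain Reshape, the equivalent check means proving that the outer target shape equals the original shape symbolically. The tuple-based graph format and simplify_join_of_split are hypothetical.

```python
import numpy as np


def simplify_join_of_split(node):
    """Hypothetical rewrite on a toy tuple-based graph.

    Nodes are ("input", name), ("split_dims", inner, axis, sizes) or
    ("join_dims", inner, axes). A join_dims that re-merges exactly the axes
    produced by the split_dims below it cancels that split. The check only
    compares axis indices; no shape graph is inspected.
    """
    if node[0] == "join_dims" and node[1][0] == "split_dims":
        _, inner, axes = node
        _, x, axis, sizes = inner
        if tuple(axes) == tuple(range(axis, axis + len(sizes))):
            return x
    return node


x = ("input", "x")
graph = ("join_dims", ("split_dims", x, 1, (3, 4)), (1, 2))
simplified = simplify_join_of_split(graph)  # the split/join pair cancels

# Numerical sanity check of the same identity with concrete arrays:
arr = np.arange(24).reshape(2, 12)
roundtrip = arr.reshape(2, 3, 4).reshape(2, 12)  # split axis 1, then join axes (1, 2)
```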
Precedent
This is somewhat related to why we have Second and Alloc. The first one is easier to reason about because it tells us more immediately that we are broadcasting with the shape of a variable, whereas Alloc specifies the desired output without its meaning (especially after some rewrites, where the shape may become dissociated from the original variable).
pytensor/pytensor/tensor/rewriting/basic.py
Lines 3 to 23 in d62f4b1