Skip to content

Commit

Permalink
[GlobalOpt] Re-enable (unpack, expand_shape) propagation.
Browse files Browse the repository at this point in the history
It was disabled because of an upstream bug. After adding more control on
the control function, we no longer need the workaround. The revision
flips the behavior to what it was.

Issue: iree-org#17734

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW committed Jul 29, 2024
1 parent 45323df commit 65640c6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ struct DataLayoutPropagationPass
patterns, [](OpOperand *opOperand) {
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
(void)consumer;
// Currently only bubble up/push down pack/unpack through
// collapse/expand shape ops.
// TODO(#17734): The propagation through expand_shape ops is broken.
// Enable the propagation once we find it useful and the upstream
// issue is fixed.
return isa<tensor::CollapseShapeOp>(producer);
if (isa<tensor::PackOp>(consumer)) {
return isa<tensor::CollapseShapeOp>(producer);
}
if (isa<tensor::UnPackOp>(producer)) {
return isa<tensor::ExpandShapeOp>(consumer);
}
return false;
});
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitOpError("folding patterns failed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
// CHECK-LABEL: func.func @push_down_unpack_through_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// TODO(#17734): Flip the check after we have better control function support.
// CHECK: tensor.unpack
// CHECK: tensor.expand_shape
// NO-CHECK: %[[C0:.+]] = arith.constant 0 : index
// NO-CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape {{.*}} : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
// NO-CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
// NO-CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// NO-CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
// NO-CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape {{.*}} : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>

0 comments on commit 65640c6

Please sign in to comment.