From 65640c66561eea2ebb33b7e7b24da0b8c3c7b46a Mon Sep 17 00:00:00 2001 From: hanhanW Date: Mon, 29 Jul 2024 16:36:41 -0700 Subject: [PATCH] [GlobalOpt] Re-enable (unpack, expand_shape) propagation. It was disabled because of an upstream bug. After adding more control on the control function, we no longer need the workaround. The revision flips the behavior to what it was. Issue: https://github.com/iree-org/iree/issues/17734 Signed-off-by: hanhanW --- .../GlobalOptimization/DataLayoutPropagation.cpp | 14 +++++++------- .../test/data_layout_propagation.mlir | 15 ++++++--------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp index 58929fc5b7d4..c56e7087c713 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp @@ -27,13 +27,13 @@ struct DataLayoutPropagationPass patterns, [](OpOperand *opOperand) { Operation *producer = opOperand->get().getDefiningOp(); Operation *consumer = opOperand->getOwner(); - (void)consumer; - // Currently only bubble up/push down pack/unpack through - // collapse/expand shape ops. - // TODO(#17734): The propagation through expand_shape ops is broken. - // Enable the propagation once we find it useful and the upstream - // issue is fixed. - return isa(producer); + if (isa(consumer)) { + return isa(producer); + } + if (isa(producer)) { + return isa(consumer); + } + return false; }); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { funcOp.emitOpError("folding patterns failed"); diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir index bd262cf3f9b5..556cfedf8fee 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir @@ -27,12 +27,9 @@ func.func @push_down_unpack_through_expand(%5: tensor, %dim: index // CHECK-LABEL: func.func @push_down_unpack_through_expand // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// TODO(#17734): Flip the check after we have better control function support. -// CHECK: tensor.unpack -// CHECK: tensor.expand_shape -// NO-CHECK: %[[C0:.+]] = arith.constant 0 : index -// NO-CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape {{.*}} : tensor into tensor -// NO-CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor -// NO-CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor -// NO-CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor -> tensor -// NO-CHECK: return %[[UNPACK]] : tensor +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape {{.*}} : tensor into tensor +// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor -> tensor +// CHECK: return %[[UNPACK]] : tensor