diff --git a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp index 58929fc5b7d4..c56e7087c713 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp @@ -27,13 +27,13 @@ struct DataLayoutPropagationPass patterns, [](OpOperand *opOperand) { Operation *producer = opOperand->get().getDefiningOp(); Operation *consumer = opOperand->getOwner(); - (void)consumer; - // Currently only bubble up/push down pack/unpack through - // collapse/expand shape ops. - // TODO(#17734): The propagation through expand_shape ops is broken. - // Enable the propagation once we find it useful and the upstream - // issue is fixed. - return isa(producer); + if (isa(consumer)) { + return isa(producer); + } + if (isa(producer)) { + return isa(consumer); + } + return false; }); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { funcOp.emitOpError("folding patterns failed"); diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir index bd262cf3f9b5..556cfedf8fee 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/data_layout_propagation.mlir @@ -27,12 +27,9 @@ func.func @push_down_unpack_through_expand(%5: tensor, %dim: index // CHECK-LABEL: func.func @push_down_unpack_through_expand // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// TODO(#17734): Flip the check after we have better control function support. -// CHECK: tensor.unpack -// CHECK: tensor.expand_shape -// NO-CHECK: %[[C0:.+]] = arith.constant 0 : index -// NO-CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape {{.*}} : tensor into tensor -// NO-CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor -// NO-CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor -// NO-CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor -> tensor -// NO-CHECK: return %[[UNPACK]] : tensor +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape {{.*}} : tensor into tensor +// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor -> tensor +// CHECK: return %[[UNPACK]] : tensor