diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index 03b9aaa7f731..a6d7380b716b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -208,9 +208,6 @@ static std::optional> getHorizontalFusionGroupMembers( if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) { return false; } - if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) { - return false; - } return true; }; @@ -240,6 +237,10 @@ static std::optional> getHorizontalFusionGroupMembers( continue; } + if (!isHorizontalToGroup(linalgUser, allOps, dominanceInfo, seedOp)) { + continue; + } + contractionOps.push_back(linalgUser); allOps.insert(linalgUser); if (contractionOps.size() >= fusionLimit) { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir index 7f724be8a2f8..7090cb7590db 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir @@ -365,3 +365,31 @@ util.func @dont_fuse_contractions_with_different_n(%lhs : tensor<10x20xf32>, // CHECK: %[[MATMUL1:.+]] = linalg.matmul // CHECK-SAME: ins(%[[LHS]], %[[RHS1]] : // CHECK: util.return %[[MATMUL0]], %[[MATMUL1]] + +// ----- + +util.func public @check_horizontal_independence(%arg0: tensor<640x640xf32>, + %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, + %arg3: tensor<640x640xf32>) + -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<640x640xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) + outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32> + %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) + outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32> + %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) + outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32> + util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32> +} +// CHECK-LABEL: func public @check_horizontal_independence +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x640xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<640x640xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<640x640xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<640x640xf32> +// CHECK: %[[FUSED_OP:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : +// CHECK: %[[OP:.+]] = linalg.matmul +// CHECK: ins(%[[ARG0]], %[[FUSED_OP]]#1 : +// CHECK: util.return %[[FUSED_OP]]#0, %[[FUSED_OP]]#1, %[[OP]]