Skip to content

Commit

Permalink
Fix partial horizontal fusion dependence violation.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Feb 14, 2025
1 parent a27128a commit dcfa537
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,6 @@ static std::optional<SmallVector<Operation *>> getHorizontalFusionGroupMembers(
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
return false;
}
return true;
};

Expand Down Expand Up @@ -240,6 +237,10 @@ static std::optional<SmallVector<Operation *>> getHorizontalFusionGroupMembers(
continue;
}

if (!isHorizontalToGroup(linalgUser, allOps, dominanceInfo, seedOp)) {
continue;
}

contractionOps.push_back(linalgUser);
allOps.insert(linalgUser);
if (contractionOps.size() >= fusionLimit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,31 @@ util.func @dont_fuse_contractions_with_different_n(%lhs : tensor<10x20xf32>,
// CHECK: %[[MATMUL1:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS1]] :
// CHECK: util.return %[[MATMUL0]], %[[MATMUL1]]

// -----

// Regression test for the dependence check in horizontal contraction fusion.
//
// %2 and %3 are two matmuls that share %arg0 as their LHS and the same
// zero-filled init %1, and neither reads the other's result. %4 also has
// %arg0 as its LHS, but its RHS is %3, so it depends on one of the other two
// contractions. %arg3 is not used by the body.
//
// The expectations that follow require %2 and %3 to be replaced by one
// two-result linalg.generic, and %4 to remain a separate linalg.matmul that
// reads the second result of that generic.
util.func public @check_horizontal_independence(%arg0: tensor<640x640xf32>,
    %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>,
    %arg3: tensor<640x640xf32>)
    -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<640x640xf32>
  // Single zero-filled accumulator shared by all three matmuls.
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
  // Independent contractions: same LHS (%arg0), different RHS (%arg1 / %arg2).
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>)
      outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>)
      outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  // Dependent contraction: consumes %3, so it is expected to stay unfused.
  %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>)
      outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
}
// CHECK-LABEL: func public @check_horizontal_independence
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x640xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<640x640xf32>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: tensor<640x640xf32>
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: tensor<640x640xf32>
//       CHECK:   %[[FUSED_OP:.+]]:2 = linalg.generic
//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
//       CHECK:   %[[OP:.+]] = linalg.matmul
//  CHECK-SAME:       ins(%[[ARG0]], %[[FUSED_OP]]#1 :
//       CHECK:   util.return %[[FUSED_OP]]#0, %[[FUSED_OP]]#1, %[[OP]]

0 comments on commit dcfa537

Please sign in to comment.