Skip to content

Commit

Permalink
Fix an issue in DeadArgElim where we fail to mark condition op of scf.if as live (#4404)
Browse files Browse the repository at this point in the history

Summary: When scf.if is marked as live in ForOpDeadArgElim, we should
mark its condition as live too. Without this fix, with the test case
added in this patch, the scf.if will be removed.
  • Loading branch information
manman-ren committed Jul 28, 2024
1 parent 92f75d3 commit 7b617bc
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,8 @@ struct ForOpDeadArgElimination : public OpRewritePattern<scf::ForOp> {
}
if (auto nestedIf = value.getDefiningOp<scf::IfOp>()) {
auto result = mlir::cast<OpResult>(value);
// mark condition as live.
markLive(nestedIf.getCondition());
for (scf::YieldOp nestedYieldOp :
{nestedIf.thenYield(), nestedIf.elseYield()}) {
Value nestedYieldOperand =
Expand Down
31 changes: 31 additions & 0 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2542,3 +2542,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
tt.return %9 : tensor<1x256xi32, #blocked>
}
}

// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// Regression test for dead loop-argument elimination on scf.for.
// The scf.if inside the loop is conditioned on the loop-carried i1 %arg4,
// and that flag has no other use: the third loop result is never read.
// When the scf.if results are live, its condition has to be kept live as
// well; otherwise the scf.if would be removed.
// The directives below require that an scf.if survives in the output and
// that no convert_layout follows it.
// CHECK-LABEL: @if_condition_not_dead_inside_loop
// CHECK: scf.if
// CHECK-NOT: convert_layout
tt.func public @if_condition_not_dead_inside_loop(%arg0: i32) -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) {
%true = arith.constant true
%cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1>
%cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%c4096_i32 = arith.constant 4096 : i32
// Loop-carried values: two tensors (one per layout) plus the i1 flag %arg4,
// which starts as %true.
%1:3 = scf.for %arg10 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %true) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1) : i32 {
// The branch condition is the loop-carried flag, not a loop-invariant value.
%3:2 = scf.if %arg4 -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>) {
scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>
} else {
%4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1>
%5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked>
%6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked>
scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>
}
// Flag for the next iteration: true when the induction variable equals %arg0.
%119 = arith.cmpi eq, %arg10, %arg0 : i32
scf.yield %3#0, %3#1, %119 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1
}
// Only loop results #0 and #1 are used below; the i1 result is unused.
%7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked>
tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked>
}
}

0 comments on commit 7b617bc

Please sign in to comment.