[BACKEND] Prevent for/yield argument number drift (#4097)
In the current implementation, when backward rematerialization
encounters a loop argument that has already been rematerialized, it
short-circuits the collection of yield operations but leaves the value
in the slice. If another loop argument is present in the same slice,
the loop is collected again, duplicating the first argument's iter_arg
without generating the corresponding yield operand.

To address this, the fix removes values that were skipped during
collection from the slice, ensuring they are not duplicated again. This
keeps the number of yield operands and loop iter_args in sync.
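
For illustration, here is a minimal standalone sketch of the deferred-removal pattern the fix relies on (not the actual pass code; `Value`, `hasRematValue`, and the collection step are hypothetical stand-ins). Erasing from an llvm::SetVector while iterating it would invalidate the loop, so skipped values are recorded first and subtracted from the slice in one pass afterwards:

#include "llvm/ADT/SetVector.h"
#include <functional>

using Value = int; // hypothetical stand-in for mlir::Value

void pruneRematerialized(llvm::SetVector<Value> &slice,
                         const std::function<bool(Value)> &hasRematValue) {
  // Values skipped during collection; erasing them mid-iteration would
  // invalidate the loop over `slice`.
  llvm::SetVector<Value> valuesWithExistingRemat;
  for (Value v : slice) {
    if (hasRematValue(v)) {
      valuesWithExistingRemat.insert(v); // defer removal
      continue; // short-circuit: no yield operand is recorded for v
    }
    // ... collect ops to rewrite and yield operands for v ...
  }
  // Drop the skipped values so a later for-op rewrite does not append
  // iter_args for them without matching yield operands.
  slice.set_subtract(valuesWithExistingRemat);
}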
amjames committed Jun 7, 2024
1 parent 1466ad9 commit b90b3a0
Showing 2 changed files with 78 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -841,12 +841,17 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
  SetVector<Operation *> opsToRewrite;
  // Keep track of yield operands that need to be duplicated.
  DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
  // Keep these around to remove them from the slice after our collection pass.
  // This ensures we don't duplicate them during a for-op rewrite, which would
  // cause the for/yield operands to fall out of sync.
  SetVector<Value> valuesWithExistingRemat;
  for (Value v : slice) {
    auto layoutIt = layout.find(v);
    assert(layoutIt != layout.end());
    // If we already have a remat value for this value, use it.
    if (hasRematValue(v, layoutIt->second)) {
      mapping.map(v, getRematValue(v, layoutIt->second));
      valuesWithExistingRemat.insert(v);
      continue;
    }
    if (v.getDefiningOp()) {
@@ -870,6 +875,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
      }
    }
  }
  slice.set_subtract(valuesWithExistingRemat);
  opsToRewrite = multiRootTopologicalSort(opsToRewrite);

  // replaceAllUsesWith calls delayed until after initial rewrite.
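As a sanity check of the SetVector semantics the new slice.set_subtract call relies on, here is a tiny self-contained program (assuming only LLVM's ADT headers are available; the values are placeholders): set_subtract removes exactly the given elements and preserves the insertion order of the rest.

#include "llvm/ADT/SetVector.h"
#include <cassert>

int main() {
  llvm::SetVector<int> slice;
  for (int v : {1, 2, 3, 4})
    slice.insert(v);

  llvm::SetVector<int> skipped;
  skipped.insert(2); // stands in for a value with an existing remat
  slice.set_subtract(skipped);

  // 1, 3, 4 remain, in their original insertion order.
  assert(slice.size() == 3);
  assert(!slice.count(2));
  assert(slice[0] == 1 && slice[1] == 3 && slice[2] == 4);
  return 0;
}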
72 changes: 72 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -2396,3 +2396,75 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // Regression test:
  // Rematerialization of multiple loop-carried variables, where one is
  // rematerialized to the same layout by multiple users.
  // Previously this didn't interact correctly with the de-duplication mechanism.
  // CHECK-LABEL: @multi_rematerialize_loop_arg
  tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<i8>) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) {
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    %c2048_i32 = arith.constant 2048 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %cst_1 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %1 = tt.load %0 : tensor<128x64x!tt.ptr<f16>, #blocked1>
    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
    %3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
    %4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
    // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>)
    // CHECK: scf.yield {{.*}} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    // CHECK: }
    // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 {
      %6 = tt.load %2 : tensor<64x64x!tt.ptr<f16>, #blocked2>
      %7 = triton_gpu.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %8 = triton_gpu.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %10 = tt.load %3 : tensor<128x64x!tt.ptr<i8>, #blocked>
      %11 = tt.load %4 : tensor<128x64x!tt.ptr<i8>, #blocked>
      %12 = arith.cmpi eq, %10, %11 : tensor<128x64xi8, #blocked>
      %13 = triton_gpu.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma>
      %14 = arith.select %13, %9, %cst_1 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>
      %15 = triton_gpu.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
      %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({
      ^bb0(%arg6: f32, %arg7: f32):
        %34 = arith.maxnumf %arg6, %arg7 : f32
        tt.reduce.return %34 : f32
      }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %19 = triton_gpu.convert_layout %18 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %20 = arith.select %18, %cst, %17 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma>
      %22 = tt.broadcast %21 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma>
      %23 = arith.select %22, %cst_2, %14 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>
      %24 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
      %25 = arith.mulf %arg4, %cst : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %26 = triton_gpu.convert_layout %25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
      %28 = tt.broadcast %27 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma>
      %29 = arith.mulf %arg3, %28 : tensor<128x64xf32, #mma>
      %30 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %31 = arith.mulf %arg4, %20 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %32 = "tt.reduce"(%24) <{axis = 1 : i32}> ({
      ^bb0(%arg6: f32, %arg7: f32):
        %34 = arith.addf %arg6, %arg7 : f32
        tt.reduce.return %34 : f32
      }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %33 = arith.addf %31, %32 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    }
    tt.return %5#1, %5#2 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
  }
}
