From b90b3a077be48e2e1bbe63e4dc5951e4c0241536 Mon Sep 17 00:00:00 2001
From: Andrew James
Date: Fri, 7 Jun 2024 13:05:17 -0500
Subject: [PATCH] [BACKEND] Prevent for/yield argument number drift (#4097)

When backward rematerialization encounters a loop argument that has
already been rematerialized, it short-circuits the collection of yield
operands but leaves the value in the slice. If another loop argument is
present in the same slice, the loop is collected again, duplicating the
first argument without generating the corresponding yield operand.

Fix this by removing values that are skipped during collection from the
slice, so they are not duplicated again. This keeps the number of yield
operands and loop iter_args in sync.
---
 .../Transforms/RemoveLayoutConversions.cpp |  6 ++
 test/TritonGPU/combine.mlir                | 72 +++++++++++++++++++
 2 files changed, 78 insertions(+)

diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index f7841b9cfa73..585d6670f162 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -841,12 +841,17 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   SetVector<Operation *> opsToRewrite;
   // Keep track of yield operands that need to be duplicated.
   DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
+  // Keep these around to remove them from the slice after our collection
+  // pass. This ensures we don't duplicate them during a for rewrite or cause
+  // the for/yield to fall out of sync.
+  SetVector<Value> valuesWithExistingRemat;
   for (Value v : slice) {
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
     if (hasRematValue(v, layoutIt->second)) {
       mapping.map(v, getRematValue(v, layoutIt->second));
+      valuesWithExistingRemat.insert(v);
       continue;
     }
     if (v.getDefiningOp()) {
@@ -870,6 +875,7 @@
       }
     }
   }
+  slice.set_subtract(valuesWithExistingRemat);
   opsToRewrite = multiRootTopologicalSort(opsToRewrite);
 
   // replaceAllUsesWith calls delayed until after initial rewrite.
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index e48b56decff4..9118bc4f2fc6 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -2396,3 +2396,75 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   }
 }
 
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // Regression test:
+  // Rematerialization of multiple loop-carried variables, where one is
+  // rematerialized to the same layout by multiple users.
+  // Previously this didn't interact correctly with the de-duplication mechanism.
+  // CHECK-LABEL: @multi_rematerialize_loop_arg
+  tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<i8>) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c2048_i32 = arith.constant 2048 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %cst_1 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
+    %1 = tt.load %0 : tensor<128x64x!tt.ptr<f16>, #blocked1>
+    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
+    %3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
+    %4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
+    // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>)
+    // CHECK: scf.yield {{.*}} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    // CHECK: }
+    // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 {
+      %6 = tt.load %2 : tensor<64x64x!tt.ptr<f16>, #blocked2>
+      %7 = triton_gpu.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %8 = triton_gpu.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+      %10 = tt.load %3 : tensor<128x64x!tt.ptr<i8>, #blocked>
+      %11 = tt.load %4 : tensor<128x64x!tt.ptr<i8>, #blocked>
+      %12 = arith.cmpi eq, %10, %11 : tensor<128x64xi8, #blocked>
+      %13 = triton_gpu.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma>
+      %14 = arith.select %13, %9, %cst_1 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>
+      %15 = triton_gpu.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
+      %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({
+      ^bb0(%arg6: f32, %arg7: f32):
+        %34 = arith.maxnumf %arg6, %arg7 : f32
+        tt.reduce.return %34 : f32
+      }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %19 = triton_gpu.convert_layout %18 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+      %20 = arith.select %18, %cst, %17 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma>
+      %22 = tt.broadcast %21 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma>
+      %23 = arith.select %22, %cst_2, %14 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>
+      %24 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
+      %25 = arith.mulf %arg4, %cst : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %26 = triton_gpu.convert_layout %25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+      %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
+      %28 = tt.broadcast %27 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma>
+      %29 = arith.mulf %arg3, %28 : tensor<128x64xf32, #mma>
+      %30 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %31 = arith.mulf %arg4, %20 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %32 = "tt.reduce"(%24) <{axis = 1 : i32}> ({
+      ^bb0(%arg6: f32, %arg7: f32):
+        %34 = arith.addf %arg6, %arg7 : f32
+        tt.reduce.return %34 : f32
+      }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      %33 = arith.addf %31, %32 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+      scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    }
+    tt.return %5#1, %5#2 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
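
For reviewers, the invariant this patch restores can be illustrated with a
small standalone C++ sketch. It is a simplified model, not the actual pass:
plain std containers stand in for llvm::SetVector, and the value names are
hypothetical. It shows why a value whose remat mapping already exists must
also be subtracted from the slice before the loop rewrite duplicates
iter_args:

#include <cassert>
#include <set>
#include <string>
#include <vector>

int main() {
  // Slice of values feeding the loop; "argA" was rematerialized earlier.
  std::set<std::string> slice = {"argA", "argB"};
  std::set<std::string> hasRematValue = {"argA"};

  // Collection pass: gather yield operands, skipping already-remat values.
  std::vector<std::string> newYieldOperands;
  std::set<std::string> valuesWithExistingRemat;
  for (const std::string &v : slice) {
    if (hasRematValue.count(v)) {
      // Short-circuit: reuse the existing remat value, no new yield operand.
      valuesWithExistingRemat.insert(v);
      continue;
    }
    newYieldOperands.push_back(v + ".yield");
  }

  // The fix (mirrors slice.set_subtract(valuesWithExistingRemat)): drop the
  // skipped values so the rewrite below does not duplicate their iter_args.
  for (const std::string &v : valuesWithExistingRemat)
    slice.erase(v);

  // Loop rewrite: duplicate an iter_arg for every value left in the slice.
  std::vector<std::string> newIterArgs;
  for (const std::string &v : slice)
    newIterArgs.push_back(v + ".remat");

  // Invariant: one new yield operand per new iter_arg.
  assert(newIterArgs.size() == newYieldOperands.size());
  return 0;
}

Without the erase loop the assert fires: two duplicated iter_args ("argA",
"argB") but only one collected yield operand ("argB"), which is exactly the
for/yield drift described above.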