From 40d5cafa37b8458cac9d461a266fb60d4f637014 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Fri, 20 Sep 2024 18:31:38 -0700
Subject: [PATCH 1/2] [BACKEND] Switch back to use llvm.load for shared memory load

When we don't have a predicate we can use llvm.load. Using inline asm for
i8 types can cause inefficient code generation in LLVM due to the
interaction with the DAG legalizer.
---
 test/Conversion/tritongpu_to_llvm.mlir        | 62 ++++---------------
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp  | 43 ++++++++++---
 2 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 3ce43e71d2a7..b60a73f80c8c 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   // CHECK-LABEL: convert_layout_blocked_blocked
   tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+    // CHECK: nvvm.barrier0
+    // CHECK-COUNT-8: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     // CHECK: nvvm.barrier0
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.inline_asm
     // CHECK-SAME: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: ld.shared
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
     tt.return
   }
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) { // CHECK-COUNT-128: st.shared.b8 // CHECK: nvvm.barrier0 - // CHECK-COUNT-8: ld.shared.v4.b32 + // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32> %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> tt.return } @@ -920,7 +884,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { - // CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32 + // CHECK: llvm.load {{.*}} -> vector<4xi32> %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> tt.return } @@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { - // CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32 + // CHECK-COUNT-8: llvm.load {{.*}} -> i32 %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> tt.return } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 30d4639cbf41..1f96ae8cd005 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) { } } +static bool isTruePred(Value pred) { + if (auto constOp = pred.getDefiningOp()) { + return cast(constOp.getValue()).getInt() != 0; + } + return false; +} + void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Value val, Value pred) const { @@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, .v(vec, /*predicate=*/vec > 1) .b(elemBitwidth); - std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); - auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) - : builder.newListOperand(vec, elemConstraint); - ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); - - Type resultTy = - vec == 1 ? Type(int_ty(elemBitwidth)) - : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); - Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); - + Value load; + if (isTruePred(pred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = load(resultTy, ptr); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = insert_val(structTy, structValue, + extract_element(load, i32_val(i)), i); + } + load = structValue; + } + } else { + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? 
+                            : builder.newListOperand(vec, elemConstraint);
+    ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
+
+    Type resultTy =
+        vec == 1
+            ? Type(int_ty(elemBitwidth))
+            : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
+    load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
+  }
   SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
   return packLLVector(loc, resultVals, rewriter);
 }

From b1b222565ddff3fc645cb4aa5532fb824c978dd7 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Fri, 20 Sep 2024 18:42:09 -0700
Subject: [PATCH 2/2] address review comment

---
 third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index 1f96ae8cd005..5813b9679ef0 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -290,7 +290,7 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
   }
 }
 
-static bool isTruePred(Value pred) {
+static bool isConstantTruePred(Value pred) {
   if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
     return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
   }
   return false;
@@ -509,7 +509,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
           .b(elemBitwidth);
 
   Value load;
-  if (isTruePred(pred)) {
+  if (isConstantTruePred(pred)) {
     Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
                              : Type(vec_ty(int_ty(elemBitwidth), vec));
     load = load(resultTy, ptr);