diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 3ce43e71d2a7..b60a73f80c8c 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   // CHECK-LABEL: convert_layout_blocked_blocked
   tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+    // CHECK: nvvm.barrier0
+    // CHECK-COUNT-8: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     // CHECK: nvvm.barrier0
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.inline_asm
     // CHECK-SAME: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: ld.shared
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
     tt.return
   }
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
   tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
     // CHECK-COUNT-128: st.shared.b8
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-8: ld.shared.v4.b32
+    // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32>
     %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
     tt.return
   }
@@ -920,7 +884,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_blocked1d_to_slice0
   tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
-    // CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32
+    // CHECK: llvm.load {{.*}} -> vector<4xi32>
     %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
     tt.return
   }
@@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_blocked1d_to_slice1
   tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
-    // CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32
+    // CHECK-COUNT-8: llvm.load {{.*}} -> i32
     %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     tt.return
   }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index 30d4639cbf41..1f96ae8cd005 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
   }
 }
 
+static bool isTruePred(Value pred) {
+  if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
+    return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
+  }
+  return false;
+}
+
 void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
                               std::optional<Value> ctaId, Value val,
                               Value pred) const {
@@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
            .v(vec, /*predicate=*/vec > 1)
            .b(elemBitwidth);
 
-  std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
-  auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
-                          : builder.newListOperand(vec, elemConstraint);
-  ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
-
-  Type resultTy =
-      vec == 1 ? Type(int_ty(elemBitwidth))
-               : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
-  Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
-
+  Value load;
+  if (isTruePred(pred)) {
+    Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
+                             : Type(vec_ty(int_ty(elemBitwidth), vec));
+    load = load(resultTy, ptr);
+    if (vec > 1) {
+      Type structTy = struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth)));
+      Value structValue = undef(structTy);
+      for (int i = 0; i < vec; i++) {
+        structValue = insert_val(structTy, structValue,
+                                 extract_element(load, i32_val(i)), i);
+      }
+      load = structValue;
+    }
+  } else {
+    std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
+    auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
+                            : builder.newListOperand(vec, elemConstraint);
+    ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
+
+    Type resultTy =
+        vec == 1
+            ? Type(int_ty(elemBitwidth))
+            : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
+    load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
+  }
   SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
   return packLLVector(loc, resultVals, rewriter);
 }