[BACKEND] Switch back to use llvm.load for shared memory load
When we don't have predicates we can use llvm.load directly. Using
inline asm for i8 types can cause inefficient code generation in LLVM
due to the interaction with the DAG legalizer.
ThomasRaoux committed Sep 21, 2024
1 parent 15734f6 commit 40d5caf
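For context, a rough before/after sketch of what this change does to a single 32-bit shared-memory load (schematic only; the asm string, constraints, and operand names below are illustrative, not copied from the lowering):

```mlir
// Before: every shared load went through predicated inline asm. The asm
// body is opaque to LLVM, so the DAG legalizer cannot merge or widen
// these loads -- especially costly for i8 element types.
%v0 = llvm.inline_asm has_side_effects asm_dialect = att
      "@$1 ld.shared.b32 $0, [$2];", "=r,b,r" %pred, %ptr
      : (i1, !llvm.ptr<3>) -> i32

// After: when the predicate is a compile-time true, emit a plain load
// that LLVM can analyze and legalize normally.
%v1 = llvm.load %ptr : !llvm.ptr<3> -> i32
```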
Showing 2 changed files with 46 additions and 59 deletions.
62 changes: 13 additions & 49 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: convert_layout_blocked_blocked
tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: nvvm.barrier0
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
+  // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+  // CHECK: nvvm.barrier0
+  // CHECK-COUNT-8: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
+  // CHECK: llvm.load
+  // CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
+  // CHECK: llvm.load
+  // CHECK: llvm.load
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: ld.shared
+  // CHECK: llvm.load
+  // CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.inline_asm
// CHECK-SAME: st.shared
// CHECK: nvvm.barrier0
-  // CHECK: ld.shared
+  // CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
tt.return
}
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
// CHECK-COUNT-128: st.shared.b8
// CHECK: nvvm.barrier0
-  // CHECK-COUNT-8: ld.shared.v4.b32
+  // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32>
%0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
tt.return
}
@@ -920,7 +884,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice0
tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
-  // CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32
+  // CHECK: llvm.load {{.*}} -> vector<4xi32>
%cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
tt.return
}
@@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
-  // CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32
+  // CHECK-COUNT-8: llvm.load {{.*}} -> i32
%cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return
}
43 changes: 33 additions & 10 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
}
}

+static bool isTruePred(Value pred) {
+  if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
+    return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
+  }
+  return false;
+}

void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
                              std::optional<Value> ctaId, Value val,
                              Value pred) const {
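The new isTruePred helper gates the fast path: only a mask that is literally an i1 constant true qualifies. A sketch of the distinction it draws (IR shapes illustrative, not from the diff):

```mlir
// Recognized: the predicate is an LLVM constant op with a non-zero value,
// so the load is unconditional and a plain llvm.load is safe.
%always = llvm.mlir.constant(true) : i1

// Not recognized: a runtime-computed mask has no defining LLVM::ConstantOp,
// so isTruePred returns false and the predicated inline-asm path is kept.
%maybe = llvm.icmp "slt" %tid, %bound : i32
```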
@@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
.v(vec, /*predicate=*/vec > 1)
.b(elemBitwidth);

-  std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
-  auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
-                          : builder.newListOperand(vec, elemConstraint);
-  ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
-
-  Type resultTy =
-      vec == 1 ? Type(int_ty(elemBitwidth))
-               : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
-  Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);

+  Value load;
+  if (isTruePred(pred)) {
+    Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
+                             : Type(vec_ty(int_ty(elemBitwidth), vec));
+    load = load(resultTy, ptr);
+    if (vec > 1) {
+      Type structTy = struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth)));
+      Value structValue = undef(structTy);
+      for (int i = 0; i < vec; i++) {
+        structValue = insert_val(structTy, structValue,
+                                 extract_element(load, i32_val(i)), i);
+      }
+      load = structValue;
+    }
+  } else {
+    std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
+    auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
+                            : builder.newListOperand(vec, elemConstraint);
+    ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
+
+    Type resultTy =
+        vec == 1
+            ? Type(int_ty(elemBitwidth))
+            : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
+    load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
+  }
SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
return packLLVector(loc, resultVals, rewriter);
}
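For example, with vec = 4 and 32-bit elements on the constant-true path, the new code emits roughly the following IR (a sketch; value names invented). Repacking the loaded vector into a struct keeps the result shape identical to what the inline-asm path produces, so the existing unpackLLElements call works unchanged for both branches:

```mlir
%v  = llvm.load %ptr : !llvm.ptr<3> -> vector<4xi32>
%u  = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
%c0 = llvm.mlir.constant(0 : i32) : i32
%e0 = llvm.extractelement %v[%c0 : i32] : vector<4xi32>
%s0 = llvm.insertvalue %e0, %u[0] : !llvm.struct<(i32, i32, i32, i32)>
// ...extractelement/insertvalue repeated for lanes 1 through 3...
```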