[BACKEND] Switch back to use llvm.load for shared memory load #4776

Merged · 2 commits · Sep 21, 2024
Changes from 1 commit
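In short: when the predicate of a shared-memory load is a compile-time constant true, the NVIDIA backend now emits a plain llvm.load instead of routing the load through predicated PTX inline assembly, which is what the test updates below check for. A minimal before/after sketch in LLVM-dialect MLIR, assuming a single 32-bit element, a constant-true predicate, and made-up value names (the asm string and constraint order are illustrative, not copied from the generated IR):

// Before: the shared-memory load was emitted as predicated PTX inline asm.
%old = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = []
    "@$2 ld.shared.b32 $0, [ $1 + 0 ];", "=r,r,b"
    %ptr, %pred : (!llvm.ptr<3>, i1) -> i32

// After: a plain LLVM-dialect load from shared memory (address space 3).
%new = llvm.load %ptr : !llvm.ptr<3> -> i32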
62 changes: 13 additions & 49 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: convert_layout_blocked_blocked
tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
// CHECK: nvvm.barrier0
// CHECK-COUNT-8: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.load
// CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.load
// CHECK: llvm.load
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.load
// CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.inline_asm
// CHECK-SAME: st.shared
// CHECK: nvvm.barrier0
// CHECK: ld.shared
// CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
tt.return
}
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
// CHECK-COUNT-128: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK-COUNT-8: ld.shared.v4.b32
// CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32>
%0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
tt.return
}
@@ -920,7 +884,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice0
tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
// CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32
// CHECK: llvm.load {{.*}} -> vector<4xi32>
%cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
tt.return
}
@@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32
// CHECK-COUNT-8: llvm.load {{.*}} -> i32
%cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return
}
43 changes: 33 additions & 10 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
}
}

static bool isTruePred(Value pred) {
  if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
    return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
  }
  return false;
}

Review comment (Contributor) on isTruePred: How about isConstantTruePred?
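For reference, the constant-true predicate this helper matches looks roughly like the LLVM-dialect constant below; value names are made up, and a predicate produced by a comparison is shown as a case the helper deliberately does not treat as true:

// isTruePred returns true: defined by an LLVM constant with a nonzero value.
%true = llvm.mlir.constant(true) : i1
// isTruePred returns false: not defined by an LLVM::ConstantOp, so the
// predicated inline-asm path is still used for loads guarded by this value.
%cmp = llvm.icmp "slt" %a, %b : i32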

void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
                              std::optional<Value> ctaId, Value val,
                              Value pred) const {
@@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
           .v(vec, /*predicate=*/vec > 1)
           .b(elemBitwidth);

  std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
  auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
                          : builder.newListOperand(vec, elemConstraint);
  ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");

  Type resultTy =
      vec == 1 ? Type(int_ty(elemBitwidth))
               : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
  Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);

  Value load;
  if (isTruePred(pred)) {
    Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
                             : Type(vec_ty(int_ty(elemBitwidth), vec));
    load = load(resultTy, ptr);
    if (vec > 1) {
      Type structTy = struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth)));
      Value structValue = undef(structTy);
      for (int i = 0; i < vec; i++) {
        structValue = insert_val(structTy, structValue,
                                 extract_element(load, i32_val(i)), i);
      }
      load = structValue;
    }
  } else {
    std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
    auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
                            : builder.newListOperand(vec, elemConstraint);
    ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");

    Type resultTy =
        vec == 1
            ? Type(int_ty(elemBitwidth))
            : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
    load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
  }
  SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
  return packLLVector(loc, resultVals, rewriter);
}
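In the hunk above, the first block (the unconditional PTXBuilder path) is the code being removed, and the isTruePred branch is its replacement: with a constant-true predicate the load becomes a single llvm.load, vectorized when vec > 1, and the result is repacked into the struct form that unpackLLElements expects. A rough sketch of the IR the vec > 1 path produces, assuming vec = 4, 32-bit elements, and made-up value names:

// One vectorized LLVM load instead of a predicated ld.shared.v4.b32.
%v = llvm.load %ptr : !llvm.ptr<3> -> vector<4xi32>

// Repack the vector into the struct layout the rest of the lowering expects;
// only element 0 is shown, the loop emits one extract/insert pair per element.
%u  = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
%c0 = llvm.mlir.constant(0 : i32) : i32
%e0 = llvm.extractelement %v[%c0 : i32] : vector<4xi32>
%s0 = llvm.insertvalue %e0, %u[0] : !llvm.struct<(i32, i32, i32, i32)>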