Skip to content

Commit

Permalink
[BACKEND] Adds back a printf function that takes StringRef msg. (#4386)
Browse files Browse the repository at this point in the history
  • Loading branch information
zahimoud committed Jul 24, 2024
1 parent 301fc18 commit 9c7f9f8
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 0 deletions.
10 changes: 10 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ class TargetInfoBase {
// placeholders in the format string.
virtual void printf(RewriterBase &rewriter, Value formatStrStart,
int formatStrByteCount, ValueRange args) const = 0;

// Emits LLVM code with |rewriter| to print a message, particularly useful for
// backend debug. |msg| is the message to print, |args| are the arguments to
// fill placeholders in the |msg|.
// NOTE: This function is used for backend debug. DO NOT DELETE.
// Example use: targetInfo.printf(rewriter,"index: %d, value: %f", {index,
// value});
virtual void printf(RewriterBase &rewriter, StringRef msg,
ValueRange args) const = 0;

// Emits LLVM code with |rewriter| to perform assertion failure with the given
// |message| from the given |func| in |file|.
virtual void assertFail(RewriterBase &rewriter, Location loc,
Expand Down
12 changes: 12 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,18 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart,
/*useStdError=*/false);
}

void TargetInfo::printf(RewriterBase &rewriter, StringRef msg,
ValueRange args) const {
assert(!msg.empty() && "printf with empty string not supported");
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value msgValue =
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter,
"printfFormat_", msgNewline);
printf(rewriter, msgValue, msgNewline.size_in_bytes(), args);
}

void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,
StringRef message, StringRef file, StringRef func,
int line) const {
Expand Down
4 changes: 4 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

void printf(RewriterBase &rewriter, Value formatStrStart,
int formatStrByteCount, ValueRange args) const override;

void printf(RewriterBase &rewriter, StringRef msg,
ValueRange args) const override;

void assertFail(RewriterBase &rewriter, Location loc, StringRef message,
StringRef file, StringRef func, int line) const override;

Expand Down
12 changes: 12 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,18 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart,
call(funcOp, operands);
}

void TargetInfo::printf(RewriterBase &rewriter, StringRef msg,
ValueRange args) const {
assert(!msg.empty() && "printf with empty string not supported");
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value msgValue =
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter,
"printfFormat_", msgNewline);
printf(rewriter, msgValue, msgNewline.size_in_bytes(), args);
}

void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,
StringRef message, StringRef file, StringRef func,
int line) const {
Expand Down
4 changes: 4 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

void printf(RewriterBase &rewriter, Value formatStrStart,
int formatStrByteCount, ValueRange args) const override;

void printf(RewriterBase &rewriter, StringRef msg,
ValueRange args) const override;

void assertFail(RewriterBase &rewriter, Location loc, StringRef message,
StringRef file, StringRef func, int line) const override;

Expand Down

0 comments on commit 9c7f9f8

Please sign in to comment.