From 9c7f9f8e3f6c769d0c32d0c51265c04ba8b164ca Mon Sep 17 00:00:00 2001 From: Zahi Moudallal Date: Wed, 24 Jul 2024 15:33:12 -0700 Subject: [PATCH] [BACKEND] Adds back a printf function that takes StringRef msg. (#4386) --- .../Conversion/TritonGPUToLLVM/TargetInfoBase.h | 10 ++++++++++ .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 12 ++++++++++++ third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 4 ++++ .../nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 12 ++++++++++++ .../nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 4 ++++ 5 files changed, 42 insertions(+) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index d2f463ea7499..ee43674c21b4 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -68,6 +68,16 @@ class TargetInfoBase { // placeholders in the format string. virtual void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const = 0; + + // Emits LLVM code with |rewriter| to print a message, particularly useful for + // backend debug. |msg| is the message to print, |args| are the arguments to + // fill placeholders in the |msg|. + // NOTE: This function is used for backend debug. DO NOT DELETE. + // Example use: targetInfo.printf(rewriter,"index: %d, value: %f", {index, + // value}); + virtual void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const = 0; + // Emits LLVM code with |rewriter| to perform assertion failure with the given // |message| from the given |func| in |file|. virtual void assertFail(RewriterBase &rewriter, Location loc, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index c2552bd3fff6..08aef1bbe5ec 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -218,6 +218,18 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, /*useStdError=*/false); } +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); +} + void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 9d42571c84ec..d5ad966a4e57 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -57,6 +57,10 @@ class TargetInfo : public mlir::triton::TargetInfoBase { void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; + + void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const override; + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const override; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 3d00cc112897..66e1b7e7ad20 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -643,6 +643,18 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, call(funcOp, operands); } +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); +} + void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 7a7cd72c716a..c7c4ef3b4cee 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -52,6 +52,10 @@ class TargetInfo : public mlir::triton::TargetInfoBase { void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; + + void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const override; + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const override;