Skip to content

Commit

Permalink
Use optimal number of VGPRs (#281)
Browse files (browse the repository at this point in the history)
* Use optimal number of VGPRs

* Fix tritongpu_to_hsaco test
  • Loading branch information
oplavsic authored Aug 4, 2023
1 parent e1de24c commit 1388445
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions lib/Target/LLVMIR/LLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/CallingConv.h"
#include "llvm/ADT/APInt.h"
Expand Down Expand Up @@ -51,7 +52,7 @@ struct NVVMMetadata {

// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
bool isROCM) {
bool isROCM, const int threadsPerCTA) {
auto *module = func->getParent();
auto &ctx = func->getContext();

Expand Down Expand Up @@ -83,7 +84,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
if (metadata.isKernel) {
if (isROCM) {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
func->addFnAttr("amdgpu-flat-work-group-size",
"1, " + std::to_string(threadsPerCTA));
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
} else {
Expand Down Expand Up @@ -312,10 +314,14 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
return nullptr;
}

const int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
const int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
const int threadsPerCTA = numWarps * warpSize;

for (auto &func : llvmModule->functions()) {
auto it = nvvmMetadata.find(func.getName());
if (it != nvvmMetadata.end())
amendLLVMFunc(&func, it->second, isROCM);
amendLLVMFunc(&func, it->second, isROCM, threadsPerCTA);
}

return llvmModule;
Expand Down
2 changes: 1 addition & 1 deletion test/Target/tritongpu_to_hsaco.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
// CHECK: .group_segment_fixed_size: 0
// CHECK-NEXT: .kernarg_segment_align: 8
// CHECK-NEXT: .kernarg_segment_size: 16
// CHECK-NEXT: .max_flat_workgroup_size: 1024
// CHECK-NEXT: .max_flat_workgroup_size: 256
// CHECK-NEXT: .name: test_empty_kernel
// CHECK-NEXT: .private_segment_fixed_size: 0

Expand Down

0 comments on commit 1388445

Please sign in to comment.