Add waves_per_eu as kernel parameter (#319)
* Add waves_per_eu as kernel parameter

* Fix failing tests

* Add default value for waves_per_eu for ttgir_to_llir function

* Remove aot.py
oplavsic authored Oct 6, 2023
1 parent be95edc commit e801638
Showing 7 changed files with 36 additions and 26 deletions.
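
For context, a minimal usage sketch of the new launch-time parameter. The add_kernel below and its shapes are hypothetical; the waves_per_eu keyword itself mirrors the tutorial change at the bottom of this diff. Per the guard added in LLVMIRTranslation.cpp below, the default of 0 leaves the generated code untouched, and positive values only take effect on the ROCm/AMDGPU path. The sketch also assumes, as the tutorial's h.asm usage suggests, that the generated launcher returns the compiled kernel object.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


n = 98432
x = torch.rand(n, device="cuda")  # "cuda" is also the device string on ROCm builds of PyTorch
y = torch.rand(n, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(n, 1024),)

# waves_per_eu is accepted next to num_warps / num_stages and defaults to 0
# ("leave the occupancy hint unset"). Assumption: the generated launcher
# returns the compiled kernel object, which the sketches further down reuse.
h = add_kernel[grid](x, y, out, n, BLOCK=1024, waves_per_eu=2)
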
bin/triton-translate.cpp (10 changes: 6 additions & 4 deletions)
@@ -126,11 +126,13 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   llvm::LLVMContext llvmContext;
   mlir::triton::gpu::TMAMetadataTy tmaInfos;
 #ifdef USE_ROCM
-  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
-                                           SMArch.getValue(), tmaInfos, Target::ROCDL);
+  auto llvmir =
+      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
+                                 tmaInfos, Target::ROCDL, 0 /*wavesPerEU*/);
 #else
-  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
-                                           SMArch.getValue(), tmaInfos, Target::Default);
+  auto llvmir =
+      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
+                                 tmaInfos, Target::Default, 0 /*wavesPerEU*/);
 #endif
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
include/triton/Target/LLVMIR/LLVMIRTranslation.h (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@ std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
                            mlir::ModuleOp module, int computeCapability,
                            mlir::triton::gpu::TMAMetadataTy &tmaInfos,
-                           Target target);
+                           Target target, int wavesPerEU);

 // Translate mlir LLVM dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
lib/Target/LLVMIR/LLVMIRTranslation.cpp (13 changes: 8 additions & 5 deletions)
@@ -59,7 +59,8 @@ struct NVVMMetadata {

 // Add the nvvm related metadata to LLVM IR.
 static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
-                          Target target, const int threadsPerCTA) {
+                          Target target, const int threadsPerCTA,
+                          const int wavesPerEU) {
   auto *module = func->getParent();
   auto &ctx = func->getContext();

@@ -102,6 +103,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
     func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
     func->addFnAttr("amdgpu-flat-work-group-size",
                     "1, " + std::to_string(threadsPerCTA));
+    if (wavesPerEU > 0)
+      func->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEU));
     func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
     func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
   } break;
@@ -283,7 +286,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,

 std::unique_ptr<llvm::Module>
 translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
-                      Target target) {
+                      Target target, int wavesPerEU) {
   DialectRegistry registry;
   mlir::registerBuiltinDialectTranslation(registry);
   mlir::registerLLVMDialectTranslation(registry);
@@ -331,7 +334,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
   for (auto &func : llvmModule->functions()) {
     auto it = nvvmMetadata.find(func.getName());
     if (it != nvvmMetadata.end())
-      amendLLVMFunc(&func, it->second, target, threadsPerCTA);
+      amendLLVMFunc(&func, it->second, target, threadsPerCTA, wavesPerEU);
   }

   return llvmModule;
@@ -341,7 +344,7 @@ std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
                            mlir::ModuleOp module, int computeCapability,
                            mlir::triton::gpu::TMAMetadataTy &tmaInfos,
-                           Target target) {
+                           Target target, int wavesPerEU) {
   mlir::PassManager pm(module->getContext());
   mlir::registerPassManagerCLOptions();
   if (failed(applyPassManagerCLOptions(pm))) {
@@ -385,7 +388,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
     return nullptr;
   }

-  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
+  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target, wavesPerEU);
   if (!llvmIR) {
     llvm::errs() << "Translate to LLVM IR failed";
     return nullptr;
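
Continuing the sketch from the top of the page, a quick way to observe the effect of the amendLLVMFunc hunk above. It assumes the compiled kernel exposes the generated LLVM IR as text under the "llir" stage name used in compiler.py below, and it only holds when compiling for the ROCm/AMDGPU target.

# Continues the sketch near the top of the page (assumptions: ROCm/AMDGPU
# target, and h.asm carries the "llir" stage output as text).
llir = h.asm["llir"]
print("amdgpu-waves-per-eu" in llir)          # True, since waves_per_eu=2 was passed
print("amdgpu-flat-work-group-size" in llir)  # True, the pre-existing attribute, for comparison
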
python/src/triton.cc (4 changes: 2 additions & 2 deletions)
@@ -1970,11 +1970,11 @@ void init_triton_translation(py::module &m) {
       "translate_triton_gpu_to_llvmir",
       [](mlir::ModuleOp op, int computeCapability,
          mlir::triton::gpu::TMAMetadataTy &tmaInfos,
-         mlir::triton::Target target) {
+         mlir::triton::Target target, int wavesPerEU) {
         py::gil_scoped_release allow_threads;
         llvm::LLVMContext llvmContext;
         auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
-            &llvmContext, op, computeCapability, tmaInfos, target);
+            &llvmContext, op, computeCapability, tmaInfos, target, wavesPerEU);
         if (!llvmModule)
           llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

python/triton/compiler/compiler.py (14 changes: 9 additions & 5 deletions)
@@ -162,14 +162,14 @@ def _add_external_libs(mod, libs):
     add_external_libs(mod, list(libs.keys()), list(libs.values()))


-def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
+def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0):
     if extern_libs:
         _add_external_libs(mod, extern_libs)
     # TODO: separate tritongpu_to_llvmir for different backends
     if _is_cuda(arch):
-        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
+        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
     else:
-        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
+        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)


# PTX translation
@@ -308,14 +308,15 @@ def make_hash(fn, arch, env_vars, **kwargs):
         num_warps = kwargs.get("num_warps", 4)
         num_ctas = kwargs.get("num_ctas", 1)
         num_stages = kwargs.get("num_stages", 3)
+        waves_per_eu = kwargs.get("waves_per_eu", 0)
         enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
         enable_persistent = kwargs.get("enable_persistent", False)
         debug = kwargs.get("debug", False)
         # Get unique key for the compiled code
         get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
         configs_key = [get_conf_key(conf) for conf in configs]
         env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
-        key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
+        key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
         return hashlib.md5(key.encode("utf-8")).hexdigest()
     assert isinstance(fn, str)
     return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
@@ -472,6 +473,7 @@ def compile(fn, **kwargs):
     assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
     num_ctas = kwargs.get("num_ctas", 1)
     num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
+    waves_per_eu = kwargs.get("waves_per_eu", 0)
     # TODO[shuhaoj]: Default should be to enable warp specialization once possible
     enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
     # TODO[shuhaoj]: persistent can be decoupled with warp specialization
@@ -499,7 +501,7 @@ def compile(fn, **kwargs):
         stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
                           lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
         stages["llir"] = (lambda path: Path(path).read_text(),
-                          lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
+                          lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
     if is_cuda:
         add_cuda_stages(arch, extern_libs, stages)
     elif is_hip:
@@ -571,6 +573,7 @@ def compile(fn, **kwargs):
             "warp_size": warp_size,
             "num_ctas": num_ctas,
             "num_stages": num_stages,
+            "waves_per_eu": waves_per_eu,
             "enable_warp_specialization": enable_warp_specialization,
             "enable_persistent": enable_persistent,
             "constants": _get_jsonable_constants(constants),
@@ -689,6 +692,7 @@ def __init__(self, fn, so_path, metadata, asm):
         self.warp_size = metadata["warp_size"]
         self.num_ctas = metadata["num_ctas"]
         self.num_stages = metadata["num_stages"]
+        self.waves_per_eu = metadata["waves_per_eu"]
         self.clusterDims = metadata["clusterDims"]
         if "tensormaps_info" in metadata:
             self.tensormaps_info = metadata["tensormaps_info"]
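
Continuing the same sketch: per the metadata and CompiledKernel.__init__ hunks above, the value chosen at compile time is recorded and can be read back from the compiled kernel object.

# Continues the sketch near the top of the page.
print(h.waves_per_eu)            # 2, as passed at launch
print(h.num_ctas, h.num_stages)  # neighboring metadata fields shown in the hunk above
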
python/triton/runtime/jit.py (14 changes: 7 additions & 7 deletions)
@@ -276,13 +276,13 @@ def _make_constants(self, constexpr_key):
         constants = dict(zip(self.constexprs, constexpr_key))
         return constants

-    def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
+    def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
         if JITFunction.cache_hook is None:
             return False
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
-        repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
+        repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
         key = str(key)

         class LegacyCompiler:
@@ -292,7 +292,7 @@ def __init__(self, module, name):
                 pass

         kwargs = dict(signature=signature, device=device, constants=constants,
-                      num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
+                      num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
                       configs=configs)

         return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
@@ -364,7 +364,7 @@ def _make_launcher(self):

         src = f"""
import triton
-def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
+def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
    from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
    sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
    if num_stages is None:
        num_stages = get_arch_default_num_stages(device_type)

-    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
+    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
    if not extern_libs is None:
        key = (key, tuple(extern_libs.items()))
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
    for i, arg in constants.items():
        if callable(arg):
            raise TypeError(f"Callable constexpr at index {{i}} is not supported")
-    if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
-        bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
+    if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
+        bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
        # Create tensormaps and append to args
        args = bin.assemble_tensormap_to_arg(args)
        if not warmup:
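
A practical consequence of the jit.py changes above, again under the assumption that the generated launcher returns the cached CompiledKernel: waves_per_eu is now part of the launch cache key (and of make_hash), so the same kernel launched with two different values is compiled and cached separately.

# Continues the sketch near the top of the page.
h2 = add_kernel[grid](x, y, out, n, BLOCK=1024, waves_per_eu=2)  # reuses the entry compiled earlier
h4 = add_kernel[grid](x, y, out, n, BLOCK=1024, waves_per_eu=4)  # new cache key, compiles again
assert h2 is not h4  # two distinct compiled binaries
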
python/tutorials/06-fused-attention.py (5 changes: 3 additions & 2 deletions)
@@ -464,6 +464,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
         BLOCK_N = 64
         num_warps = 4
         num_stages = 1
+        waves_per_eu = 2 if causal else 3

         grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
@@ -481,7 +482,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
             IS_CAUSAL=causal,
             num_warps=num_warps,
-            num_stages=num_stages)
+            num_stages=num_stages, waves_per_eu=waves_per_eu)

         ctx.save_for_backward(q, k, v, o, L)
         ctx.grid = grid
@@ -560,7 +561,7 @@ def backward(ctx, do):
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
             num_stages=1,
         )
         # print(h.asm["ttgir"])