diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index 7d2cb19e6360..3ac0f23fcb5c 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -126,11 +126,13 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   llvm::LLVMContext llvmContext;
   mlir::triton::gpu::TMAMetadataTy tmaInfos;
 #ifdef USE_ROCM
-  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
-                                           SMArch.getValue(), tmaInfos, Target::ROCDL);
+  auto llvmir =
+      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
+                                 tmaInfos, Target::ROCDL, 0 /*wavesPerEU*/);
 #else
-  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
-                                           SMArch.getValue(), tmaInfos, Target::Default);
+  auto llvmir =
+      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
+                                 tmaInfos, Target::Default, 0 /*wavesPerEU*/);
 #endif
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
index fa810121060c..84adda2e46fd 100644
--- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h
+++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
@@ -29,7 +29,7 @@ std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
                            mlir::ModuleOp module, int computeCapability,
                            mlir::triton::gpu::TMAMetadataTy &tmaInfos,
-                           Target target);
+                           Target target, int wavesPerEU);
 
 // Translate mlir LLVM dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
index 265c411057c7..45f0fedaf898 100644
--- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -59,7 +59,8 @@ struct NVVMMetadata {
 
 // Add the nvvm related metadata to LLVM IR.
 static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
-                          Target target, const int threadsPerCTA) {
+                          Target target, const int threadsPerCTA,
+                          const int wavesPerEU) {
   auto *module = func->getParent();
   auto &ctx = func->getContext();
 
@@ -102,6 +103,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
     func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
     func->addFnAttr("amdgpu-flat-work-group-size",
                     "1, " + std::to_string(threadsPerCTA));
+    if (wavesPerEU > 0)
+      func->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEU));
     func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
     func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
   } break;
@@ -283,7 +286,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
 
 std::unique_ptr<llvm::Module>
 translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
-                      Target target) {
+                      Target target, int wavesPerEU) {
   DialectRegistry registry;
   mlir::registerBuiltinDialectTranslation(registry);
   mlir::registerLLVMDialectTranslation(registry);
@@ -331,7 +334,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
   for (auto &func : llvmModule->functions()) {
     auto it = nvvmMetadata.find(func.getName());
     if (it != nvvmMetadata.end())
-      amendLLVMFunc(&func, it->second, target, threadsPerCTA);
+      amendLLVMFunc(&func, it->second, target, threadsPerCTA, wavesPerEU);
   }
 
   return llvmModule;
@@ -341,7 +344,7 @@ std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
                            mlir::ModuleOp module, int computeCapability,
                            mlir::triton::gpu::TMAMetadataTy &tmaInfos,
-                           Target target) {
+                           Target target, int wavesPerEU) {
   mlir::PassManager pm(module->getContext());
   mlir::registerPassManagerCLOptions();
   if (failed(applyPassManagerCLOptions(pm))) {
@@ -385,7 +388,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
     return nullptr;
   }
 
-  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
+  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target, wavesPerEU);
   if (!llvmIR) {
     llvm::errs() << "Translate to LLVM IR failed";
     return nullptr;
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 710ec3f2e09b..29276493e8cd 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1970,11 +1970,11 @@ void init_triton_translation(py::module &m) {
       "translate_triton_gpu_to_llvmir",
       [](mlir::ModuleOp op, int computeCapability,
          mlir::triton::gpu::TMAMetadataTy &tmaInfos,
-         mlir::triton::Target target) {
+         mlir::triton::Target target, int wavesPerEU) {
         py::gil_scoped_release allow_threads;
         llvm::LLVMContext llvmContext;
         auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
-            &llvmContext, op, computeCapability, tmaInfos, target);
+            &llvmContext, op, computeCapability, tmaInfos, target, wavesPerEU);
         if (!llvmModule)
           llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");
 
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index e5f8f6662817..ef106702b702 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -162,14 +162,14 @@ def _add_external_libs(mod, libs):
     add_external_libs(mod, list(libs.keys()), list(libs.values()))
 
 
-def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
+def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0):
     if extern_libs:
         _add_external_libs(mod, extern_libs)
     # TODO: separate tritongpu_to_llvmir for different backends
     if _is_cuda(arch):
-        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
+        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
     else:
-        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
+        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
 
 
 # PTX translation
@@ -308,6 +308,7 @@ def make_hash(fn, arch, env_vars, **kwargs):
         num_warps = kwargs.get("num_warps", 4)
         num_ctas = kwargs.get("num_ctas", 1)
         num_stages = kwargs.get("num_stages", 3)
+        waves_per_eu = kwargs.get("waves_per_eu", 0)
         enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
         enable_persistent = kwargs.get("enable_persistent", False)
         debug = kwargs.get("debug", False)
@@ -315,7 +316,7 @@ def make_hash(fn, arch, env_vars, **kwargs):
         get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
         configs_key = [get_conf_key(conf) for conf in configs]
         env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
-        key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
+        key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
         return hashlib.md5(key.encode("utf-8")).hexdigest()
     assert isinstance(fn, str)
     return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
@@ -472,6 +473,7 @@ def compile(fn, **kwargs):
     assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
     num_ctas = kwargs.get("num_ctas", 1)
     num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
+    waves_per_eu = kwargs.get("waves_per_eu", 0)
     # TODO[shuhaoj]: Default should be to enable warp specialization once possible
     enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
     # TODO[shuhaoj]: persistent can be decoupled with warp specialization
@@ -499,7 +501,7 @@ def compile(fn, **kwargs):
     stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
                        lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
     stages["llir"] = (lambda path: Path(path).read_text(),
-                      lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
+                      lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
     if is_cuda:
         add_cuda_stages(arch, extern_libs, stages)
     elif is_hip:
@@ -571,6 +573,7 @@ def compile(fn, **kwargs):
                     "warp_size": warp_size,
                     "num_ctas": num_ctas,
                     "num_stages": num_stages,
+                    "waves_per_eu": waves_per_eu,
                     "enable_warp_specialization": enable_warp_specialization,
                     "enable_persistent": enable_persistent,
                     "constants": _get_jsonable_constants(constants),
@@ -689,6 +692,7 @@ def __init__(self, fn, so_path, metadata, asm):
         self.warp_size = metadata["warp_size"]
         self.num_ctas = metadata["num_ctas"]
         self.num_stages = metadata["num_stages"]
+        self.waves_per_eu = metadata["waves_per_eu"]
         self.clusterDims = metadata["clusterDims"]
         if "tensormaps_info" in metadata:
             self.tensormaps_info = metadata["tensormaps_info"]
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index aad0f57a66f0..6dc9de5a7c60 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -276,13 +276,13 @@ def _make_constants(self, constexpr_key):
         constants = dict(zip(self.constexprs, constexpr_key))
         return constants
 
-    def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
+    def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
         if JITFunction.cache_hook is None:
             return False
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
-        repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
+        repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
         key = str(key)
 
         class LegacyCompiler:
@@ -292,7 +292,7 @@ def __init__(self, module, name):
                 pass
 
         kwargs = dict(signature=signature, device=device, constants=constants,
-                      num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
+                      num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
                       configs=configs)
 
         return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
@@ -364,7 +364,7 @@ def _make_launcher(self):
 
         src = f"""
 import triton
-def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
+def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
     from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
     sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
     constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
     if num_stages is None:
         num_stages = get_arch_default_num_stages(device_type)
 
-    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
+    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
 
     if not extern_libs is None:
         key = (key, tuple(extern_libs.items()))
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
         for i, arg in constants.items():
             if callable(arg):
                 raise TypeError(f"Callable constexpr at index {{i}} is not supported")
-        if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
-            bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
+        if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
+            bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
         # Create tensormaps and append to args
         args = bin.assemble_tensormap_to_arg(args)
         if not warmup:
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 0de71c697f4d..027cdc31e289 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -464,6 +464,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
         BLOCK_N = 64
         num_warps = 4
         num_stages = 1
+        waves_per_eu = 2 if causal else 3
         grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
 
@@ -481,7 +482,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
             IS_CAUSAL=causal,
             num_warps=num_warps,
-            num_stages=num_stages)
+            num_stages=num_stages, waves_per_eu=waves_per_eu)
 
         ctx.save_for_backward(q, k, v, o, L)
         ctx.grid = grid
@@ -560,7 +561,7 @@ def backward(ctx, do):
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
             num_stages=1,
         )
         # print(h.asm["ttgir"])
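
Usage sketch (not part of the patch): with this change applied, waves_per_eu is accepted as a launch-time keyword next to num_warps/num_stages, is folded into the compilation cache key, and on the ROCDL path is emitted as the "amdgpu-waves-per-eu" LLVM function attribute (0 keeps the compiler's default occupancy heuristic). The vector-add kernel below is a hypothetical illustration of passing the knob, not code taken from this diff.

import torch
import triton
import triton.language as tl

@triton.jit
def _vec_add(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Hypothetical kernel used only to show the launch-time keyword.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")  # ROCm builds of PyTorch expose HIP devices as "cuda" too
y = torch.randn_like(x)
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# waves_per_eu rides alongside num_warps/num_stages; it only takes effect on the ROCDL (AMD) path.
_vec_add[grid](x, y, out, x.numel(), BLOCK=1024, num_warps=4, waves_per_eu=2)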