From c6d33dcebf5b8c04d7a4b567c51327c6157a12cd Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 31 Aug 2023 17:02:00 -0400 Subject: [PATCH 001/122] [ROCM] Core Functionality for AMD (#1983) * this PR adds a third-party backend for Triton that works on AMD * it exposes a lot of the work that has been done in our [fork](https://github.com/ROCmSoftwarePlatform/triton) * most unit tests in `test_core.py` pass * it skips some unit tests for various reasons * we plan to follow up with more PRs improving functionality and performance in the future --------- Co-authored-by: Philippe Tillet --- .github/workflows/integration-tests.yml | 6 +- .gitmodules | 4 + CMakeLists.txt | 1 - bin/CMakeLists.txt | 1 - bin/triton-translate.cpp | 10 +- .../triton/Target/AMDGCN/AMDGCNTranslation.h | 19 -- .../triton/Target/HSACO/HSACOTranslation.h | 21 -- include/triton/Tools/Sys/GetEnv.hpp | 2 +- lib/Target/CMakeLists.txt | 1 - lib/Target/HSACO/CMakeLists.txt | 9 - lib/Target/HSACO/HSACOTranslation.cpp | 182 -------------- python/src/triton.cc | 20 +- python/test/unit/language/test_core.py | 232 ++++++++++++++---- python/triton/common/backend.py | 3 +- python/triton/compiler/compiler.py | 162 ++++-------- python/triton/compiler/make_launcher.py | 155 +----------- python/triton/language/math.py | 9 +- python/triton/language/semantic.py | 40 +++ python/triton/runtime/jit.py | 6 +- third_party/amd_hip_backend | 1 + 20 files changed, 290 insertions(+), 594 deletions(-) delete mode 100644 include/triton/Target/AMDGCN/AMDGCNTranslation.h delete mode 100644 include/triton/Target/HSACO/HSACOTranslation.h delete mode 100644 lib/Target/HSACO/CMakeLists.txt delete mode 100644 lib/Target/HSACO/HSACOTranslation.cpp create mode 160000 third_party/amd_hip_backend diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ca785a1a5d56..1b83b566359f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -27,7 +27,7 @@ jobs: run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' - echo '::set-output name=matrix-optional::[]' + echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' else echo '::set-output name=matrix-required::["ubuntu-latest"]' echo '::set-output name=matrix-optional::["ubuntu-latest"]' @@ -209,10 +209,12 @@ jobs: - name: Install Triton on ROCM if: ${{ env.BACKEND == 'ROCM'}} run: | + git submodule update --init --recursive cd python python3 -m pip install --upgrade pip python3 -m pip install cmake==3.24 python3 -m pip install torch==1.13.1 --index-url https://download.pytorch.org/whl/rocm5.2 + export TRITON_CODEGEN_AMD_HIP_BACKEND=1 python3 -m pip install --no-build-isolation -vvv '.[tests]' - name: Install Triton on XPU @@ -234,7 +236,7 @@ jobs: if: ${{ env.BACKEND == 'ROCM'}} run: | cd python/test/unit/language - python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py::test_empty_kernel" + python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py" - name: Run python tests on XPU if: ${{ env.BACKEND == 'XPU'}} diff --git a/.gitmodules b/.gitmodules index 1638f552cb27..30ba4342537e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,7 @@ [submodule "third_party/intel_xpu_backend"] path = third_party/intel_xpu_backend url = http://github.com/intel/intel-xpu-backend-for-triton +[submodule "third_party/amd_hip_backend"] + path = third_party/amd_hip_backend +
url = https://github.com/ROCmSoftwarePlatform/triton + branch = third_party_backend_2 diff --git a/CMakeLists.txt b/CMakeLists.txt index 7dd62d1a3dea..bb709df6d8d1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,7 +212,6 @@ if(TRITON_BUILD_PYTHON_MODULE) TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX - TritonHSACO ${dialect_libs} ${conversion_libs} diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index a8966c5e77f4..2b2f6afeb5ce 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -53,7 +53,6 @@ llvm_update_compile_flags(triton-translate) TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX - TritonHSACO ${dialect_libs} ${conversion_libs} # tests diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp index b7f02484f288..49ca1322397c 100644 --- a/bin/triton-translate.cpp +++ b/bin/triton-translate.cpp @@ -15,7 +15,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" #include "llvm/IR/LLVMContext.h" @@ -131,16 +130,11 @@ LogicalResult tritonTranslateMain(int argc, char **argv, llvm::errs() << "Translate to LLVM IR failed"; } - if (targetKind == "llvmir") + if (targetKind == "llvmir") { llvm::outs() << *llvmir << '\n'; - else if (targetKind == "ptx") + } else if (targetKind == "ptx") { llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(), ptxVersion.getValue()); - else if (targetKind == "hsaco") { - auto [module, hsaco] = ::triton::translateLLVMIRToHSACO( - *llvmir, GCNArch.getValue(), GCNTriple.getValue(), - GCNFeatures.getValue()); - llvm::outs() << hsaco; } else { llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n"; return failure(); diff --git a/include/triton/Target/AMDGCN/AMDGCNTranslation.h b/include/triton/Target/AMDGCN/AMDGCNTranslation.h deleted file mode 100644 index c20f1924db5c..000000000000 --- a/include/triton/Target/AMDGCN/AMDGCNTranslation.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H -#define TRITON_TARGET_AMDGCNTRANSLATION_H - -#include -#include - -namespace llvm { -class Module; -} // namespace llvm - -namespace triton { - -// Translate LLVM IR to AMDGCN code. -std::tuple -translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc); - -} // namespace triton - -#endif diff --git a/include/triton/Target/HSACO/HSACOTranslation.h b/include/triton/Target/HSACO/HSACOTranslation.h deleted file mode 100644 index 21ab10d6d095..000000000000 --- a/include/triton/Target/HSACO/HSACOTranslation.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef TRITON_TARGET_HSACOTRANSLATION_H -#define TRITON_TARGET_HSACOTRANSLATION_H - -#include -#include -#include - -namespace llvm { -class Module; -} // namespace llvm - -namespace triton { - -// Translate TritonGPU IR to HSACO code. 
-std::tuple -translateLLVMIRToHSACO(llvm::Module &module, std::string gfx_arch, - std::string gfx_triple, std::string gfx_features); - -} // namespace triton - -#endif diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 53e421ef218b..bf682af946b1 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -46,7 +46,7 @@ inline std::string getenv(const char *name) { inline bool getBoolEnv(const std::string &env) { std::string msg = "Environment variable " + env + " is not recognized"; - assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() && + assert(::triton::ENV_VARS.find(env.c_str()) != ::triton::ENV_VARS.end() && msg.c_str()); const char *s = std::getenv(env.c_str()); std::string str(s ? s : ""); diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt index 99cf364fab4d..9b24f0ff225b 100644 --- a/lib/Target/CMakeLists.txt +++ b/lib/Target/CMakeLists.txt @@ -1,3 +1,2 @@ add_subdirectory(LLVMIR) add_subdirectory(PTX) -add_subdirectory(HSACO) diff --git a/lib/Target/HSACO/CMakeLists.txt b/lib/Target/HSACO/CMakeLists.txt deleted file mode 100644 index ea59a0619d53..000000000000 --- a/lib/Target/HSACO/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -add_mlir_translation_library(TritonHSACO - HSACOTranslation.cpp - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - TritonLLVMIR - ) diff --git a/lib/Target/HSACO/HSACOTranslation.cpp b/lib/Target/HSACO/HSACOTranslation.cpp deleted file mode 100644 index ff2c2ea3aa0b..000000000000 --- a/lib/Target/HSACO/HSACOTranslation.cpp +++ /dev/null @@ -1,182 +0,0 @@ -#include "triton/Target/HSACO/HSACOTranslation.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Export.h" -#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" -#include "triton/Target/LLVMIR/LLVMIRTranslation.h" -#include "triton/Tools/Sys/GetEnv.hpp" - -#include "llvm/ExecutionEngine/ExecutionEngine.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IRPrintingPasses.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Verifier.h" -#include "llvm/MC/TargetRegistry.h" -#include "llvm/Support/CodeGen.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include -#include -#include -#include - -namespace { - -void init_llvm() { - LLVMInitializeAMDGPUTarget(); - LLVMInitializeAMDGPUTargetInfo(); - LLVMInitializeAMDGPUTargetMC(); - LLVMInitializeAMDGPUAsmParser(); - LLVMInitializeAMDGPUAsmPrinter(); -} - -std::unique_ptr -initialize_module(llvm::Module *module, const std::string &triple, - const std::string &proc, const std::string &features) { - // verify and store llvm - llvm::legacy::PassManager pm; - pm.add(llvm::createVerifierPass()); - pm.run(*module); - - module->setTargetTriple(triple); - - std::string error; - auto target = - 
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); - llvm::TargetOptions opt; - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine( - module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, llvm::CodeGenOpt::Aggressive); - - module->setDataLayout(machine->createDataLayout()); - - for (llvm::Function &f : module->functions()) - f.addFnAttr(llvm::Attribute::AlwaysInline); - - return std::unique_ptr(machine); -} - -std::string generate_amdgcn_assembly(llvm::Module *module, - const std::string &triple, - const std::string &proc, - const std::string &features) { - auto machine = initialize_module(module, triple, proc, features); - llvm::SmallVector buffer; - llvm::legacy::PassManager pass; - llvm::raw_svector_ostream stream(buffer); - - // emit - machine->addPassesToEmitFile(pass, stream, nullptr, - llvm::CodeGenFileType::CGFT_AssemblyFile); - pass.run(*module); - - std::string amdgcn(buffer.begin(), buffer.end()); - if (::triton::tools::getBoolEnv("AMDGCN_ENABLE_DUMP")) { - std::cout << "// -----// AMDGCN Dump //----- //\n" << amdgcn << std::endl; - } - - return amdgcn; -} - -std::string generate_hsaco(llvm::Module *module, const std::string &triple, - const std::string &proc, - const std::string &features) { - auto machine = initialize_module(module, triple, proc, features); - - // create unique dir for kernel's binary and hsaco - std::error_code ec; - std::string kernel_name_base = "amd_triton_kernel"; - std::filesystem::path tmp = std::filesystem::temp_directory_path(); - std::filesystem::path kernel_dir_base(kernel_name_base); - llvm::SmallString<256> unique_dir; - ec = llvm::sys::fs::createUniqueDirectory((tmp / kernel_dir_base).string(), - unique_dir); - if (ec) { - std::cerr << "Directory for " << kernel_name_base - << " was not created. error code: " << ec << std::endl; - } - std::filesystem::path kernel_dir(unique_dir.data()); - std::string kernel_name = kernel_dir.stem(); - - // Save GCN ISA binary. - std::filesystem::path isa_binary(kernel_name + ".o"); - std::string isabin_path = (kernel_dir / isa_binary).string(); - std::unique_ptr isabin_fs( - new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text)); - if (ec) { - std::cerr << isabin_path << " was not created. 
error code: " << ec - << std::endl; - } - - // emit - llvm::legacy::PassManager pass; - machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, - llvm::CGFT_ObjectFile); - pass.run(*module); - - // generate HASCO file - std::filesystem::path hsaco(kernel_name + ".hsaco"); - std::string hsaco_path = (kernel_dir / hsaco).string(); - std::string error_message; - std::string lld_path = "/opt/rocm/llvm/bin/ld.lld"; - int lld_result = llvm::sys::ExecuteAndWait( - lld_path, - {lld_path, "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path}, - std::nullopt, {}, 0, 0, &error_message); - if (lld_result) { - std::cout << "ld.lld execute fail: " << std::endl; - std::cout << error_message << std::endl; - std::cout << lld_result << std::endl; - } - - return hsaco_path; -} - -std::tuple -llir_to_amdgcn_and_hsaco(llvm::Module *module, std::string gfx_arch, - std::string gfx_triple, std::string gfx_features) { - - init_llvm(); - - // verify and store llvm - auto module_obj = llvm::CloneModule(*module); - auto amdgcn = - generate_amdgcn_assembly(module, gfx_triple, gfx_arch, gfx_features); - auto hsaco_path = - generate_hsaco(module_obj.get(), gfx_triple, gfx_arch, gfx_features); - - return std::make_tuple(amdgcn, hsaco_path); -} - -} // namespace - -namespace triton { - -std::tuple -translateLLVMIRToHSACO(llvm::Module &module, std::string gfx_arch, - std::string gfx_triple, std::string gfx_features) { - auto hsacoCode = - llir_to_amdgcn_and_hsaco(&module, gfx_arch, gfx_triple, gfx_features); - return hsacoCode; -} - -} // namespace triton diff --git a/python/src/triton.cc b/python/src/triton.cc index 1a947a8a7df4..6ac87d6c34fe 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -33,7 +33,6 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" -#include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" #include "triton/Target/PTX/TmaMetadata.h" @@ -257,6 +256,7 @@ void init_triton_ir(py::module &&m) { // we load LLVM because the frontend uses LLVM.undef for // some placeholders self.getOrLoadDialect(); + self.getOrLoadDialect(); }); // .def(py::init([](){ // mlir::MLIRContext context; @@ -1958,24 +1958,6 @@ void init_triton_translation(py::module &m) { const std::vector &paths) { ::mlir::triton::addExternalLibs(op, names, paths); }); - - m.def( - "translate_llvmir_to_hsaco", - [](const std::string llvmIR, std::string gfx_arch, std::string gfx_triple, - std::string gfx_features) -> std::tuple { - // create LLVM module from C++ - llvm::LLVMContext context; - std::unique_ptr buffer = - llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); - llvm::SMDiagnostic error; - std::unique_ptr module = - llvm::parseIR(buffer->getMemBufferRef(), error, context); - // translate module to HSACO - auto hsacoCode = triton::translateLLVMIRToHSACO( - *module, gfx_arch, gfx_triple, gfx_features); - return hsacoCode; - }, - ret::take_ownership); } void init_triton(py::module &m) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index da329be6267b..9961dc51cff5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -12,6 +12,7 @@ import triton import triton._C.libtriton.triton as _triton import triton.language as tl +from triton.common.build import is_hip from triton.runtime.jit import JITFunction, TensorWrapper, 
reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] @@ -25,6 +26,13 @@ # num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] num_ctas_list = [1] +if is_hip(): + GPU_DIALECT = "triton_gpu_rocm" + THREADS_PER_WARP = 64 +else: + GPU_DIALECT = "triton_gpu" + THREADS_PER_WARP = 32 + def _bitwidth(dtype: str) -> int: # ex.: "int64" -> 64 @@ -137,7 +145,7 @@ def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_orde self.instr_shape = str(instr_shape) def __str__(self): - return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + return f"#{GPU_DIALECT}.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" class BlockedLayout: @@ -151,7 +159,7 @@ def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" class SharedLayout: @@ -165,7 +173,7 @@ def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) @@ -851,6 +859,8 @@ def test_abs(dtype_x, device): @pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') @triton.jit def abs_kernel(X, Z, SIZE: tl.constexpr): @@ -1056,6 +1066,9 @@ def noinline_multi_values_fn(x, y, Z): @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): + if is_hip() and mode == "shared": + pytest.skip('test_noinline["shared"] not supported on HIP.') + @triton.jit def kernel(X, Y, Z): x = tl.load(X) @@ -1141,6 +1154,9 @@ def kernel(X, Z): else: np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) sem_str = "acq_rel" if sem is None else sem + if is_hip(): + return + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] @@ -1232,6 +1248,8 @@ def serialized_add(data, Lock, SEM: tl.constexpr): h = serialized_add[(64,)](data, Lock, SEM=sem, num_ctas=num_ctas) sem_str = "acq_rel" if sem is None else sem np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if is_hip(): + return assert 
f"atom.global.{sem_str}" in h.asm["ptx"] @@ -1261,6 +1279,9 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): check_type_supported(dtype_x, device) check_type_supported(dtype_z, device) + if is_hip() and (dtype_z == "bfloat16"): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + size = 1024 # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. if dtype_x.startswith('bfloat'): @@ -1358,7 +1379,10 @@ def kernel(in_out_ptr): for _ in range(1000): x = torch.ones((65536,), device=device, dtype=torch.float32) - kernel[(65536,)](x, num_warps=32) + if is_hip(): + kernel[(65536,)](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536,)](x, num_warps=32) assert torch.all(x == 2) @@ -1452,6 +1476,8 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): """ check_type_supported(in_dtype, device) check_type_supported(out_dtype, device) + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') @triton.jit def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): @@ -1507,6 +1533,9 @@ def get_reduced_dtype(dtype_str, op): def test_reduce1d(op, dtype_str, shape, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip(): + pytest.skip(f"test_reduce1d not supported on HIP") + # triton kernel @triton.jit def kernel(X, Z, BLOCK: tl.constexpr): @@ -1597,7 +1626,10 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip(): + pytest.skip(f"test_reduce2d not supported on HIP") # triton kernel + @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) @@ -1667,6 +1699,8 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp @pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs) def test_scan2d(op, dtype_str, shape, axis, num_warps, device): + if is_hip(): + pytest.skip("test_scan2d is not supported in HIP") check_type_supported(dtype_str, device) # triton kernel @@ -1720,6 +1754,9 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) def test_scan_layouts(M, N, src_layout, axis, device): + if is_hip(): + pytest.skip("test_scan_layouts is not supported in HIP") + ir = f""" #blocked = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ @@ -1783,6 +1820,9 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("axis", [0, 1]) def test_reduce_layouts(M, N, src_layout, axis, device): + if is_hip(): + pytest.skip("test_reduce_layouts is not supported in HIP") + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" rdims_1d = f"{N}" if axis == 0 else f"{M}" store_range = "%7" if axis == 0 else "%1" @@ -1792,28 +1832,28 @@ def test_reduce_layouts(M, N, src_layout, axis, device): #src = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : 
i32}}) {{ - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> - %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> - %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> - %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<{rdims_2d}x!tt.ptr, #blocked> %12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr, #blocked>, tensor<{rdims_2d}xi32, #blocked> %13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked> - %14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src> + %14 = {GPU_DIALECT}.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src> %15 = "tt.reduce"(%14) ({{ ^bb0(%arg3: i32, %arg4: i32): %17 = arith.addi %arg3, %arg4 : i32 tt.reduce.return %17 : i32 - }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>> - %18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>> - %19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked> + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + %18 = {GPU_DIALECT}.convert_layout %15 : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked> tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked> tt.return }} @@ -1854,17 +1894,20 @@ def test_reduce_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("M", [32, 64, 128, 256]) 
@pytest.mark.parametrize("src_layout", layouts) def test_store_op(M, src_layout, device): + if is_hip(): + pytest.skip("test_convert1d is not supported yet in HIP") + ir = f""" #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src> - %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> tt.store %8, %4 : tensor<{M}x1xf32, #src> @@ -1903,20 +1946,23 @@ def test_store_op(M, src_layout, device): @pytest.mark.parametrize("src_dim", [0, 1]) @pytest.mark.parametrize("dst_dim", [0, 1]) def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): + if is_hip(): + pytest.skip("test_convert1d is not supported in HIP") + ir = f""" #dst = {dst_layout} #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> - %1 = 
tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> - %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> - %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> - %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> - %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> - %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> - %7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> - tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.return }} }} @@ -1962,6 +2008,9 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): @pytest.mark.parametrize("op", ["sum", "max"]) @pytest.mark.parametrize("first_axis", [0, 1]) def test_chain_reduce(M, N, src_layout, op, device, first_axis): + if is_hip(): + pytest.skip("test_chain_reduce is not supported in HIP") + op_str = "" if op == "sum": op_str = f""" @@ -1969,19 +2018,19 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): tt.reduce.return %13 : i32""" elif op == "max": op_str = f""" - %13 = "triton_gpu.cmpi"(%arg2, %arg3) <{{predicate = 4 : i64}}> : (i32, i32) -> i1 + %13 = "{GPU_DIALECT}.cmpi"(%arg2, %arg3) <{{predicate = 4 : i64}}> : (i32, i32) -> i1 %14 = arith.select %13, %arg2, %arg3 : i32 tt.reduce.return %14 : i32""" ir = f""" #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public 
@sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> - %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> - %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> @@ -1991,11 +2040,11 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): %11 = "tt.reduce"(%10) ({{ ^bb0(%arg2: i32, %arg3: i32): {op_str} - }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #triton_gpu.slice<{{dim = {first_axis}, parent = #src}}>> + }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>> %12 = "tt.reduce"(%11) ({{ ^bb0(%arg2: i32, %arg3: i32): {op_str} - }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #triton_gpu.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 + }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32 tt.return }} @@ -2063,6 +2112,8 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_permute(dtype_str, shape, perm, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip(): + pytest.skip(f"test_permute is not supported in HIP") # triton kernel @triton.jit @@ -2099,6 +2150,10 @@ def kernel(X, stride_xm, stride_xn, # compare np.testing.assert_allclose(to_numpy(z_tri), z_ref) np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if is_hip(): + return + # parse ptx to make sure ld/st are vectorized ptx = pgm.asm['ptx'] assert 'ld.global.v4' in ptx @@ -2115,7 +2170,7 @@ def kernel(X, stride_xm, stride_xn, @pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) - for shape in [(64, 64, 64), (16, 16, 16)] + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] for allow_tf32 in [True, False] for 
in_dtype, out_dtype in [('float16', 'float16'), @@ -2146,6 +2201,17 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o check_cuda_only(device) capability = torch.cuda.get_device_capability() + + if is_hip(): + # set capability to large number to jump over check below + # check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests + capability = (100, 100) + if out_dtype is None: + if in_dtype in float_dtypes: + out_dtype = "float32" + else: + out_dtype = "int32" + if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") if capability[0] < 8: @@ -2160,6 +2226,16 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o # TODO: support out_dtype=float16 for tl.dot on V100 pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if is_hip(): + if (M, N, K) in [(64, 128, 128)]: + pytest.skip(f"test_dot{(M, N, K)} not supported on HIP: memory out of resource.") + if (M, N, K, num_warps) in [(128, 256, 32, 8), (128, 128, 64, 4)]: + pytest.skip(f"test_dot{(M, N, K)} not supported on HIP. Reduce Warp to work") + if M == 16 or N == 16 or K == 16: + pytest.skip(f"test_dot{(M, N, K)} segfaults on HIP") + if epilogue == "softmax": + pytest.skip(f"test_dot{epilogue} segfaults on HIP") + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 if num_ctas > 1 and in_dtype == 'int8': @@ -2247,6 +2323,7 @@ def kernel(X, stride_xm, stride_xk, out_dtype = tl.float16 else: out_dtype = tl.float32 + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, w_tri.stride(0), w_tri.stride(1), @@ -2261,20 +2338,24 @@ def kernel(X, stride_xm, stride_xk, ALLOW_TF32=allow_tf32, num_warps=num_warps, num_ctas=num_ctas, out_dtype=out_dtype) + if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32): - ptx = pgm.asm["ptx"] - start = ptx.find("shfl.sync") - end = ptx.find("cvt.rn.f16.f32") - red_code = ptx[start:end] - assert len(red_code) > 0 - import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() - # skip this check on hopper because there are some functions whose name contain "shared" in ptx. - # TODO: we should eliminate these unused functions in ptx code. - if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]): - assert "shared" not in red_code - assert "bar.sync" not in red_code + if is_hip(): + pass + else: + ptx = pgm.asm["ptx"] + start = ptx.find("shfl.sync") + end = ptx.find("cvt.rn.f16.f32") + red_code = ptx[start:end] + assert len(red_code) > 0 + import os + enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() + enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. 
+ if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]): + assert "shared" not in red_code + assert "bar.sync" not in red_code # torch result if in_dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), @@ -2300,9 +2381,12 @@ def kernel(X, stride_xm, stride_xk, # XXX: Somehow there's a larger difference when we use float32 np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) elif out_dtype == tl.float16: - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) else: - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if is_hip(): + return # make sure ld/st are vectorized ptx = pgm.asm['ptx'] if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): @@ -2366,6 +2450,9 @@ def kernel(Z, X, Y, h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) z_ref = np.matmul(x, y) np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if is_hip(): + return assert "tt.dot" in h.asm['ttir'] # with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y # as the loaded value is in rowmajor. But MMAv3 requires it's second @@ -2432,6 +2519,9 @@ def test_dot_without_load(dtype_str, device): capability = torch.cuda.get_device_capability() allow_tf32 = capability[0] > 7 + if is_hip() and dtype_str == "float16": + pytest.skip("test_dot_without_load[float16] not supported in HIP") + @triton.jit def _kernel(out, ALLOW_TF32: tl.constexpr): a = GENERATE_TEST_HERE @@ -2512,6 +2602,9 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): # FIXME: Shape too small for ldmatrix when num_ctas=4 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device): + if is_hip(): + pytest.skip("test_masked_load_shared_memory is not supported in HIP") + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested M = 32 @@ -2571,6 +2664,9 @@ def _kernel(dst, src, CACHE: tl.constexpr): tl.store(dst + offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) + if is_hip(): + return + ptx = pgm.asm['ptx'] if cache == '': assert 'ld.global.ca' not in ptx @@ -2597,6 +2693,10 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): tl.store(dst + offsets, x, mask=offsets < N) pgm = _kernel[(1,)]( dst, src, N=N, BLOCK_SIZE=block_size) + + if is_hip(): + return + ptx = pgm.asm["ptx"] if N % 16 == 0: assert "ld.global.v4.b32" in ptx @@ -2620,6 +2720,9 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if is_hip(): + return + ptx = pgm.asm["ptx"] if has_hints: assert "ld.global.v4.b32" in ptx @@ -2642,6 +2745,8 @@ def _kernel(dst, src, CACHE: tl.constexpr): x = tl.load(src + offsets) tl.store(dst + offsets, x, cache_modifier=CACHE) + if is_hip(): + return pgm = _kernel[(1,)](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] if cache == '': @@ -2793,6 +2898,9 @@ def kernel(VALUE, X): @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @pytest.mark.parametrize("is_rhs_constexpr", [True, False]) def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + if is_hip(): + if (is_rhs_constexpr, 
is_lhs_constexpr, op) in [(False, False, "<<"), (False, False, ">>"), (False, True, "<<")]: + pytest.skip(f"test_bin_op_constexpr[{is_lhs_constexpr}-{is_rhs_constexpr}-{op}] is not supported in HIP") @triton.jit def kernel(Z, X, Y): @@ -2968,6 +3076,9 @@ def _kernel(dst): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): + if is_hip() and expr == "math.scalbn": + pytest.skip("test_math_tensor[math.scalbn] is not supported in HIP") + @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) @@ -3063,6 +3174,9 @@ def kernel(X, Y, BLOCK: tl.constexpr): def test_inline_asm(num_ctas, device): check_cuda_only(device) + if is_hip(): + pytest.skip("test_inline_asm is not supported in HIP") + @triton.jit def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) @@ -3089,6 +3203,9 @@ def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): def test_inline_asm_packed(num_ctas, device): check_cuda_only(device) + if is_hip(): + pytest.skip("test_inline_asm is not supported in HIP") + @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) @@ -3392,6 +3509,8 @@ def nested_while(data, countPtr): def test_globaltimer(device): + if is_hip(): + pytest.skip("test_globaltimer is not supported in HIP") check_cuda_only(device) @triton.jit @@ -3411,6 +3530,8 @@ def kernel(Out1, Out2): def test_smid(device): + if is_hip(): + pytest.skip("test_smid is not supported in HIP") check_cuda_only(device) @triton.jit @@ -3456,6 +3577,9 @@ def kernel(Out): @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): + if is_hip(): + pytest.skip("test_convert2d is not supported in HIP") + if str(src_layout) == str(dst_layout): pytest.skip() if 'mma' in str(src_layout) and 'mma' in str(dst_layout): diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index eaa15671fccf..5b60c1377caa 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -5,6 +5,7 @@ import os import re import subprocess +import traceback from typing import Dict from ..runtime.driver import DriverBase @@ -94,7 +95,7 @@ def get_backend(device_type: str): try: importlib.import_module(device_backend_package_name, package=__spec__.name) except Exception: - return None + traceback.print_exc() else: return None return _backends[device_type] if device_type in _backends else None diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index d252f3f04133..51c00add462a 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -5,18 +5,18 @@ import json import os import re -import subprocess import tempfile from collections import namedtuple from pathlib import Path -from typing import Any, Tuple +from typing import Any from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, get_num_warps, get_shared_memory_size, ir, runtime, - translate_llvmir_to_hsaco, translate_llvmir_to_ptx, + translate_llvmir_to_ptx, translate_triton_gpu_to_llvmir) from ..common.backend import get_backend, path_to_ptxas +from ..common.build import is_hip # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources @@ -188,72 +188,6 @@ def ptx_to_cubin(ptx: str, arch: 
int): return compile_ptx_to_cubin(ptx, ptxas, arch) -# AMDGCN translation - -def get_amdgcn_bitcode_paths(arch): - gpu_arch_agnostic_bitcode_libraries = ["opencl.bc", - "ocml.bc", - "ockl.bc", - "oclc_finite_only_off.bc", - "oclc_daz_opt_off.bc", - "oclc_correctly_rounded_sqrt_on.bc", - "oclc_unsafe_math_off.bc", - "oclc_wavefrontsize64_on.bc"] - - gfx_arch = arch[1] - gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip() - - gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc" - bitcode_path_dir = os.path.join(Path(__file__).parent.resolve(), "third_party/rocm/lib/bitcode/") - - amdgcn_bitcode_paths = {} - i = 1 - for bc_lib in gpu_arch_agnostic_bitcode_libraries: - bc_path = bitcode_path_dir + bc_lib - if os.path.exists(bc_path): - amdgcn_bitcode_paths['library_' + str(i)] = bc_path - i += 1 - bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library - if os.path.exists(bc_gfx_path): - amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path - - return amdgcn_bitcode_paths - - -def get_amdgpu_arch_fulldetails(): - """ - get the amdgpu fulll ISA details for compiling: - i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack- - """ - try: - # TODO: package rocm.cc with Triton - rocm_path_dir = os.getenv("ROCM_PATH", default="/opt/rocm") - rocminfo = subprocess.check_output(rocm_path_dir + '/bin/rocminfo').decode() - gfx_arch_details = re.search('amd.*', rocminfo).group(0).strip().split('--') - arch_triple = gfx_arch_details[0] - arch_name_features = gfx_arch_details[1].split(':') - arch_name = arch_name_features[0] - arch_features = "" - - if (len(arch_name_features) == 3): - arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\ - "-" + re.search('\\w+', arch_name_features[2]).group(0) - return [arch_triple, arch_name, arch_features] - except BaseException: - return None - - -def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]: - ''' - Translate TritonGPU module to HSACO code based on full details of gpu architecture. 
- :param mod: a TritonGPU dialect module - :return: - - AMDGCN code - - Path to HSACO object - ''' - return translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features) - - # ------------------------------------------------------------------------------ # compiler # ------------------------------------------------------------------------------ @@ -320,8 +254,10 @@ def make_hash(fn, arch, env_vars, **kwargs): "ttgir": mlir_arg_type_pattern, "ptx": ptx_arg_type_pattern, } - -ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' +if is_hip(): + ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:' +else: + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' def _get_jsonable_constants(constants): @@ -354,17 +290,10 @@ def _is_cuda(arch): def get_architecture_descriptor(capability): - try: - import torch - except ImportError: - raise ImportError("Triton requires PyTorch to be installed") if capability is None: - if torch.version.hip is None: - device = get_current_device() - capability = get_device_capability(device) - capability = capability[0] * 10 + capability[1] - else: - capability = get_amdgpu_arch_fulldetails() + device = get_current_device() + capability = get_device_capability(device) + capability = capability[0] * 10 + capability[1] return capability @@ -394,23 +323,6 @@ def get_arch_default_num_stages(device_type, capability=None): return num_stages -def add_rocm_stages(arch, extern_libs, stages): - extern_libs.update(get_amdgcn_bitcode_paths(arch)) - - for key in list(extern_libs): - if extern_libs[key] == '' or extern_libs[key] is None: - extern_libs.pop(key) - - gfx_arch_full_details = arch - gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1]) - if gfx_arch is None: - raise RuntimeError('gfx_arch is None (not specified)') - stages["amdgcn"] = (lambda path: Path(path).read_text(), - lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch, - gfx_arch_full_details[0], - gfx_arch_full_details[2])) - - def add_cuda_stages(arch, extern_libs, stages): stages["ptx"] = (lambda path: Path(path).read_text(), @@ -422,10 +334,13 @@ def add_cuda_stages(arch, extern_libs, stages): def compile(fn, **kwargs): # Get device type to decide which backend should be used device_type = kwargs.get("device_type", "cuda") - _device_backend = get_backend(device_type) capability = kwargs.get("cc", None) - if device_type in ["cuda", "hip"]: + if is_hip(): + device_type = "hip" + + if device_type == "cuda": + _device_backend = get_backend(device_type) arch = get_architecture_descriptor(capability) else: _device_backend = get_backend(device_type) @@ -433,7 +348,8 @@ def compile(fn, **kwargs): arch = _device_backend.get_architecture_descriptor(**kwargs) is_cuda = device_type == "cuda" and _is_cuda(arch) - is_hip = device_type in ["cuda", "hip"] and not is_cuda + if is_hip(): + is_cuda = False context = ir.context() constants = kwargs.get("constants", dict()) num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) @@ -464,14 +380,20 @@ def compile(fn, **kwargs): stages["ast"] = (lambda path: fn, None) stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch)) - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, 
optimize_epilogue)) - stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) if is_cuda: + stages["ttgir"] = (lambda path: parse_mlir_module(path, context), + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + stages["llir"] = (lambda path: Path(path).read_text(), + lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) add_cuda_stages(arch, extern_libs, stages) - elif is_hip: - add_rocm_stages(arch, extern_libs, stages) + elif device_type == "hip": + _device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) + elif device_type == "xpu": + stages["ttgir"] = (lambda path: parse_mlir_module(path, context), + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + stages["llir"] = (lambda path: Path(path).read_text(), + lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) + _device_backend.add_stages(arch, extern_libs, stages) else: # pass the user's configuration to the backend device. arch["num_warps"] = num_warps @@ -564,7 +486,7 @@ def compile(fn, **kwargs): path = metadata_group.get(ir_filename) if path is None: next_module = compile_kernel(module) - if ir == "amdgcn": + if ir_name == "amdgcn": extra_file_name = f"{name}.hsaco_path" metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename) metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) @@ -587,17 +509,23 @@ def compile(fn, **kwargs): else: asm[ir_name] = str(next_module) if ir_name == "llir" and "shared" not in metadata: - metadata["shared"] = get_shared_memory_size(module) + if is_hip(): + metadata["shared"] = _device_backend.get_shared_memory_size(module) + else: + metadata["shared"] = get_shared_memory_size(module) if ir_name == "ttgir": - metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) - if metadata["enable_warp_specialization"]: - metadata["num_warps"] = get_num_warps(next_module) + if is_hip(): + metadata["num_warps"] = _device_backend.get_num_warps(next_module) + else: + metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) + if metadata["enable_warp_specialization"]: + metadata["num_warps"] = get_num_warps(next_module) if ir_name == "ptx": metadata["name"] = get_kernel_name(next_module, pattern='// .globl') if ir_name == "amdgcn": metadata["name"] = get_kernel_name(next_module[0], pattern='.globl') asm["hsaco_path"] = next_module[1] - if not is_cuda and not is_hip: + if not is_cuda and not is_hip(): _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm) module = next_module @@ -622,7 +550,7 @@ def compile(fn, **kwargs): ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs} # cache manager - if is_cuda or is_hip: + if is_cuda: so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) else: so_path = _device_backend.make_launcher_stub(name, signature, constants, ids) @@ -660,7 +588,7 @@ def __init__(self, fn, so_path, metadata, asm): self.tensormaps_info = metadata["tensormaps_info"] self.constants = 
metadata["constants"] self.device_type = metadata["device_type"] - self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None + self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None # initialize asm dict self.asm = asm # binaries are lazily initialized @@ -674,7 +602,7 @@ def _init_handles(self): if self.cu_module is not None: return - if self.device_type in ["cuda", "hip"]: + if self.device_type in ["cuda"]: device = get_current_device() bin_path = { driver.HIP: "hsaco_path", @@ -720,7 +648,7 @@ def __getitem__(self, grid): def runner(*args, stream=None): args_expand = self.assemble_tensormap_to_arg(args) if stream is None: - if self.device_type in ["cuda", "hip"]: + if self.device_type in ["cuda"]: stream = get_cuda_stream() else: stream = get_backend(self.device_type).get_stream(None) diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index c61ad9095eca..a97d7d11b57b 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -3,16 +3,11 @@ import tempfile from ..common import _build +from ..common.build import is_hip from ..runtime.cache import get_cache_manager from ..runtime.jit import version_key from .utils import generate_cu_signature - -def is_hip(): - import torch - return torch.version.hip is not None - - # ----- stub -------- @@ -103,151 +98,9 @@ def format_of(ty): format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) # generate glue code - if is_hip(): - src = f""" - #define __HIP_PLATFORM_AMD__ - #include - #include - #include - - static inline void gpuAssert(hipError_t code, const char *file, int line) - {{ - if (code != HIP_SUCCESS) - {{ - const char* prefix = "Triton Error [HIP]: "; - const char* str = hipGetErrorString(code); - char err[1024] = {{0}}; - snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str ); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} - - #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - - static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - if (gridX*gridY*gridZ > 0) {{ - HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} - }} - - typedef struct _DevicePtrInfo {{ - hipDeviceptr_t dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - - if (ptr) {{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - - ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret); - - if (!ptr_info.dev_ptr) - return ptr_info; - - uint64_t dev_ptr; - hipError_t status = hipPointerGetAttribute(&dev_ptr, 
HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == hipErrorInvalidValue) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} - - ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr; - return ptr_info; - }} - - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; - }} - - static PyObject* launch(PyObject* self, PyObject* args) {{ - - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int num_warps; - int shared_memory; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; - - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ - return NULL; - }} - - if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ - return NULL; - }} - - // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); - Py_END_ALLOW_THREADS; - - if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; - }} - - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ - else: - folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] - params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] - src = f""" + folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] + params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] + src = f""" #include \"cuda.h\" #include #include diff --git a/python/triton/language/math.py b/python/triton/language/math.py index c84f5f01e435..6f8b0aced0e7 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -1,17 +1,18 @@ import functools import os +from ..common.build import is_hip from . 
import core @functools.lru_cache() def libdevice_path(): - import torch third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party") - if torch.version.hip is None: - default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc") + if is_hip(): + default = os.path.join(third_party_dir, "hip", "lib", "bitcode", "cuda2gcn.bc") else: - default = '' + default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc") + return os.getenv("TRITON_LIBDEVICE_PATH", default) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8984cb4da4f8..8cccda9bed92 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Tuple, TypeVar from .._C.libtriton.triton import ir +from ..common.build import is_hip from . import core as tl T = TypeVar('T') @@ -1239,6 +1240,19 @@ def atomic_xchg(ptr: tl.tensor, # ===----------------------------------------------------------------------===// +def gpu_has_mfma() -> bool: + if not is_hip(): + return False + return True # mfma supported in ['gfx908', 'gfx90a'] + + +def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: + if not gpu_has_mfma(): + return False + # TODO: Add check for configurations and types. + return True + + def dot(lhs: tl.tensor, rhs: tl.tensor, allow_tf32: bool, @@ -1292,6 +1306,32 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): M = lhs.type.shape[0] N = rhs.type.shape[1] + + # Cast operands of types f16 and i8 for configurations where FMA only supported. + if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty): + ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty + lhs = cast(lhs, ret_cast_scalar_ty, builder) + rhs = cast(rhs, ret_cast_scalar_ty, builder) + if ret_cast_scalar_ty == tl.float16: + _0 = builder.create_splat(builder.get_fp16(0), [M, N]) + else: + _0 = builder.create_splat(builder.get_fp32(0), [M, N]) + ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + return cast(ret, ret_scalar_ty, builder) + if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: + if lhs.type.scalar.is_int(): + ret_dot_scalar_ty = tl.int32 + _0 = builder.create_splat(builder.get_int32(0), [M, N]) + else: + ret_dot_scalar_ty = tl.float32 + _0 = builder.create_splat(builder.get_fp32(0), [M, N]) + ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N]) + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + return cast(ret, ret_scalar_ty, builder) + _0 = builder.create_splat(_0, [M, N]) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index aad0f57a66f0..062df8d993d1 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -383,20 +383,20 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu device_type = self._conclude_device_type(device_types, {pinned_memory_flags}) device_backend = None - if device_type not in ['cuda', 'hip']: + if device_type not in ['cuda']: device_backend = get_backend(device_type) if device_backend is None: raise ValueError('Cannot find backend for ' + device_type) if device is None: - if device_type in ['cuda', 
'hip']: + if device_type in ['cuda']: device = get_current_device() set_current_device(device) else: device = device_backend.get_current_device() device_backend.set_current_device(device) if stream is None and not warmup: - if device_type in ['cuda', 'hip']: + if device_type in ['cuda']: stream = get_cuda_stream(device) else: stream = device_backend.get_stream() diff --git a/third_party/amd_hip_backend b/third_party/amd_hip_backend new file mode 160000 index 000000000000..d0ad70d55df3 --- /dev/null +++ b/third_party/amd_hip_backend @@ -0,0 +1 @@ +Subproject commit d0ad70d55df3ebe11cc80bbb364a91551e6b6248 From 9b8c48f25d51c362c8d7628c30cb0dcadbd3a2b5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 31 Aug 2023 22:02:56 -0700 Subject: [PATCH 002/122] [BACKEND] More minor backend fixups (#2223) * Fix bug in V100 convert layout * Do not push elementwise between convert and dot for V100 --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 2 +- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 3 ++- test/TritonGPU/dot-operands.mlir | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d76ac0c0cc11..1aadf5093884 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -346,7 +346,7 @@ struct ConvertLayoutOpConversion SmallVector numCTAsEachRep(rank, 1); SmallVector shapePerCTATile = getShapePerCTATile(layout, shape); SmallVector shapePerCTA = getShapePerCTA(layout, shape); - auto elemTy = type.getElementType(); + auto elemTy = getTypeConverter()->convertType(type.getElementType()); int ctaId = 0; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 43f411bcae7d..14a050472800 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -247,7 +247,8 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); - patterns.add(context); + if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) + patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 1fbfaa9d4d07..ded8d0613bb9 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -15,7 +15,7 @@ #BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capability" = 80} { // CHECK: tt.func @push_elementwise // CHECK: %[[ALOAD:.*]] = tt.load %arg0 @@ -69,7 +69,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout( #blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes 
{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capability" = 80} { // CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> // CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> @@ -104,7 +104,7 @@ tt.func @push_convert_both_operands( #blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capability" = 80} { // CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> // CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> From 1367f3a6d23cc80d8489a3da46c64246bb3aa0c9 Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Thu, 31 Aug 2023 23:19:40 -0700 Subject: [PATCH 003/122] [FRONTEND/OPS] wap `stride_vn` and `stride_vk` in flash attention (#2208) I'm not sure if this was a typo or if I'm missing something. To me code like ``` (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) ``` seems off. In case this is a typo I made this PR to correct it. This PR should have no functional changes. If this is not a typo would you mind explaining the reasoning behind these variable names? 
--- python/triton/ops/flash_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index ce649a96cfda..53481fb21bd2 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -21,7 +21,7 @@ def _fwd_kernel( Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -50,7 +50,7 @@ def _fwd_kernel( V_block_ptr = tl.make_block_ptr( base=V + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), + strides=(stride_vn, stride_vk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0) @@ -137,7 +137,7 @@ def _bwd_kernel_one_col_block( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, off_hz, start_n, num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -159,7 +159,7 @@ def _bwd_kernel_one_col_block( # initialize pointers to value-like data q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) # pointer to row-wise quantities in value-like data @@ -212,7 +212,7 @@ def _bwd_kernel_one_col_block( q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_qm # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) + dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) @@ -228,7 +228,7 @@ def _bwd_kernel( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -259,7 +259,7 @@ def _bwd_kernel( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, off_hz, start_n, num_block_n, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, @@ -276,7 +276,7 @@ def _bwd_kernel( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, off_hz, start_n, num_block_n, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, From 691c96355083f7082e697b8d08d8019a744f0f13 Mon Sep 17 00:00:00 2001 From: Qingyi Liu Date: Fri, 1 Sep 2023 14:20:15 +0800 Subject: [PATCH 004/122] [BACKEND] Fix several bugs related to multiCTA (#2170) Still work in progress --- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 32 +++++++++++++++---- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 16 ++++++++++ 
.../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 21 ++++++++---- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index acb7f39b6493..4180fee6d7d6 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -914,6 +914,18 @@ struct StoreAsyncOpConversion const TensorPtrMapT *tensorPtrMap; }; +namespace { +void createBarrier(ConversionPatternRewriter &rewriter, Location loc, + int numCTAs) { + if (numCTAs == 1) { + barrier(); + } else { + rewriter.create(loc, false); + rewriter.create(loc); + } +} +} // namespace + struct AtomicCASOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { @@ -934,6 +946,10 @@ struct AtomicCASOpConversion auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for AtomicCASOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + Value llPtr = adaptor.getPtr(); Value llCmp = adaptor.getCmp(); Value llVal = adaptor.getVal(); @@ -971,7 +987,7 @@ struct AtomicCASOpConversion atom.global().o(semStr).o("cas").o("b32"); atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); - barrier(); + createBarrier(rewriter, loc, numCTAs); PTXBuilder ptxBuilderStore; auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); @@ -981,9 +997,9 @@ struct AtomicCASOpConversion st(dstOprStore, valOprStore).predicate(mask); auto ASMReturnTy = void_ty(ctx); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - barrier(); + createBarrier(rewriter, loc, numCTAs); Value ret = load(atomPtr); - barrier(); + createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); return success(); } @@ -1008,7 +1024,11 @@ struct AtomicRMWOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); - // + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for AtomicRMWOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + auto atomicRmwAttr = op.getAtomicRmwOp(); Value val = op.getVal(); @@ -1139,9 +1159,9 @@ struct AtomicRMWOpConversion auto *valOpr = ptxBuilderStore.newOperand(old, tyId); storeShared(ptrOpr, valOpr).predicate(rmwMask); ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); - barrier(); + createBarrier(rewriter, loc, numCTAs); Value ret = load(atomPtr); - barrier(); + createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); } } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 35f2b7e5a221..a00e105cf5ee 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -11,6 +11,7 @@ using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; using ::mlir::LLVM::shflSync; using ::mlir::LLVM::storeShared; +using ::mlir::triton::gpu::getCTASplitNum; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getTotalElemsPerThread; @@ -28,6 +29,12 @@ struct ReduceOpConversion LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // When cross-CTA reduction is implemented in the future, this assertion can + // be removed + 
assert(isReduceWithinCTA(op) && + "Layout optimization passes such as PlanCTAPass and " + "RemoveLayoutConversionPass should avoid cross-CTA reduction"); + if (ReduceOpHelper(op).isFastReduction()) return matchAndRewriteFast(op, adaptor, rewriter); return matchAndRewriteBasic(op, adaptor, rewriter); @@ -36,6 +43,15 @@ struct ReduceOpConversion private: int computeCapability; + bool isReduceWithinCTA(triton::ReduceOp op) const { + auto axis = op.getAxis(); + ReduceOpHelper helper(op); + auto srcLayout = helper.getSrcLayout(); + auto CTASplitNum = getCTASplitNum(srcLayout); + assert(axis < CTASplitNum.size()); + return CTASplitNum[axis] == 1; + } + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, SmallVector &acc, ValueRange cur, bool isFirst) const { if (isFirst) { diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 68361055c8e1..ba0423203a0b 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -541,6 +541,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto tensorTy = valueTy.dyn_cast(); Value mask = int_val(1, 1); auto tid = tid_val(); + auto clusterCTAId = getClusterCTAId(rewriter, loc); if (tensorTy) { auto layout = tensorTy.getEncoding(); auto shape = tensorTy.getShape(); @@ -576,7 +577,6 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto CTASplitNum = triton::gpu::getCTASplitNum(layout); auto CTAOrder = triton::gpu::getCTAOrder(layout); - auto clusterCTAId = getClusterCTAId(rewriter, loc); auto multiDimClusterCTAId = delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); @@ -586,14 +586,23 @@ class ConvertTritonGPUOpToLLVMPatternBase { continue; // This wrapping rule must be consistent with emitCTAOffsetForLayout unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); - multiDimClusterCTAId[dim] = - urem(multiDimClusterCTAId[dim], i32_val(splitNum)); - mask = and_(mask, icmp_eq(multiDimClusterCTAId[dim], _0)); + Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum)); + // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: + // CTA0 and CTA2 holds data of block0, + // CTA1 and CTA3 holds data of block1. + // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should + // be masked. We add the following mask: + // multiDimClusterCTAId[dim] / splitNum == 0 + // Actually in all existing cases of multicast, splitNum is always 1. + // The mask is equivalent to: + // multiDimClusterCTAId[dim] == 0 + mask = and_(mask, icmp_eq(repId, _0)); } } } else { - // If the tensor is not ranked, then it is a scalar and only thread 0 can - // write + // If the tensor is not ranked, then it is a scalar and only thread 0 of + // CTA0 can write + mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0))); mask = and_(mask, icmp_eq(tid, i32_val(0))); } return mask; From a4df60e20a4211fd987afd7e99eb8a125b9d460d Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:30:42 -0700 Subject: [PATCH 005/122] [FRONTEND] Fix GIL handling in error conditions (#2225) The use of the opaque GIL state APIs should mean that the PyErr_SetString is now safe, regardless of whether the caller has the GIL or not. 
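For readers skimming the diff, a minimal sketch of the guard this patch adds around error reporting (the helper name here is illustrative; the real code inlines this pattern into the existing gpuAssert helpers and also formats the driver error code into the message):

```c
#include <Python.h>

/* Report a driver error from C code that may or may not already hold the GIL.
 * PyGILState_Ensure() attaches the calling thread to the interpreter, acquiring
 * the GIL if necessary, so the PyErr_SetString() call is safe; PyGILState_Release()
 * then restores whatever state the thread was in. Assumes the interpreter is
 * initialized, which holds when called from a Python extension such as these
 * launchers. */
static void report_runtime_error(const char *msg) {
  PyGILState_STATE gil_state = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, msg);
  PyGILState_Release(gil_state);
}
```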
--- python/triton/compiler/make_launcher.py | 3 +++ python/triton/runtime/backends/cuda.c | 3 +++ python/triton/runtime/backends/hip.c | 3 +++ 3 files changed, 9 insertions(+) diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index a97d7d11b57b..eb8079e84754 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -116,7 +116,10 @@ def format_of(ty): char err[1024] = {{0}}; strcat(err, prefix); strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); }} }} diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 622588e8d0ff..7dd60528f28f 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -11,7 +11,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line) { char err[1024] = {0}; strcat(err, prefix); strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); } } diff --git a/python/triton/runtime/backends/hip.c b/python/triton/runtime/backends/hip.c index 5ed5f19ce837..c419132fe240 100644 --- a/python/triton/runtime/backends/hip.c +++ b/python/triton/runtime/backends/hip.c @@ -13,7 +13,10 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) { const char *str = hipGetErrorString(code); char err[1024] = {0}; snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); } } } From acbf716889987f62a43a09ff7f499d0cf203830e Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Fri, 1 Sep 2023 16:40:31 -0700 Subject: [PATCH 006/122] [BACKEND] Refactoring NVGPUToLLVMPass (#2158) --- include/triton/Dialect/NVGPU/IR/NVGPUOps.td | 10 +- .../NVGPUToLLVM/NVGPUToLLVMPass.cpp | 1646 ++++++++--------- .../TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 2 +- test/NVGPU/test_cga.mlir | 8 - test/NVGPU/test_wgmma.mlir | 34 +- 5 files changed, 793 insertions(+), 907 deletions(-) diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td index a9451984c7ce..e8e2c91e63ac 100644 --- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td @@ -97,7 +97,15 @@ def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode", } def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> { - let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode); + let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode, I64Attr:$swizzling); + let builders = [ + OpBuilder<(ins "Value":$buffer, + "Value":$height, + "WGMMADescMode":$mode), [{ + uint32_t mode_ = static_cast(mode); + uint64_t swizzling = (mode_ == 1 ? 128 : mode_ == 2 ? 
64 : 32); + build($_builder, $_state, $_builder.getIntegerType(64), buffer, height, WGMMADescModeAttr::get($_builder.getContext(), mode), $_builder.getI64IntegerAttr(swizzling)); + }]>]; let results = (outs I64:$res); let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)"; } diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index f52256b3ade4..366750780673 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -21,12 +21,245 @@ using ::mlir::LLVM::getSRegValue; namespace { +using OperandsAndConstraints = std::vector>; +typedef std::vector Constraints; + +const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;"; +const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;"; +const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;"; +const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;"; +const std::string Wgmma_Wait_Group_Op = + "wgmma.wait_group.sync.aligned #pendings;"; +const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;"; +const std::string Fence_Mbarrier_Init_Op = + "fence.mbarrier_init.release.cluster;"; +const std::string Cga_Barrier_Arrive_Op = "barrier.cluster.arrive;"; +const std::string Cga_Barrier_Wait_Op = "barrier.cluster.wait;"; +const std::string Reg_Dealloc_Op = "setmaxnreg.dec.sync.aligned.u32 #regCount;"; +const std::string Wgmma_Desc_Create_op = + "{\n" + ".reg .u64 a<5>; \n" + "mov.u64 a0, #swizzling;\n" + "shl.b64 a1, a0, 3;\n" // stride dimension + "shr.b64 a1, a1, 4;\n" // stride dimension + "mul.lo.u64 a2, $2, #swizzling;\n" // leadingDimension + "shr.b64 a2, a2, 4;\n" // leadingDimension + "shl.b64 a3, $1, 46; \n" // startAddr + "shr.b64 a3, a3, 50; \n" // startAddr + "mov.u64 a4, #mode; \n" // mode + "shl.b64 a4, a4, 62; \n" + "shl.b64 a1, a1, 32; \n" + "or.b64 a1, a4, a1; \n" + "shl.b64 a2, a2, 16; \n" + "or.b64 a1, a1, a2; \n" + "or.b64 $0, a1, a3; \n" + "}"; + +const std::string Mbarrier_Init_Op = + "@$1 mbarrier.init.shared.b64 [$0], #count;"; +const std::string Mbarrier_Wait_Op = + "{ \n" + ".reg .pred P1; \n" + "LAB_WAIT: \n" + "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, 0x989680; \n" + "@P1 bra.uni DONE; \n" + "bra.uni LAB_WAIT; \n" + "DONE: \n" + "} \n"; +const std::string Named_Barrier_Arrive_Op = "bar.arrive $0, $1;"; +const std::string Named_Barrier_Wait_Op = "bar.sync $0, $1;"; +const std::string Sts64_Op = "st.shared.v2.b32 [$0], {$1, $2};"; +const std::string Cluster_Cta_Id_Op = "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %cluster_ctaid.x;\n" // x + "mov.u32 a1, %cluster_ctaid.y;\n" // y + "mov.u32 a2, %cluster_ctaid.z;\n" // z + "mov.u32 a3, %cluster_nctaid.x;\n" // nx + "mov.u32 a4, %cluster_nctaid.y;\n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 $0, a1, a3, a0; \n" + "}"; + +bool isNumber(const std::string &s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { + return !std::isdigit(c); + }) == s.end(); +} + +Type getTypeFromConstraint(char constraint, mlir::PatternRewriter &rewriter) { + Type ty; + if (constraint == 'b') + ty = IntegerType::get(rewriter.getContext(), 1); + else if (constraint == 'h') + ty = IntegerType::get(rewriter.getContext(), 16); + else if (constraint == 'r') + ty = IntegerType::get(rewriter.getContext(), 32); + else if (constraint == 'l') + ty = IntegerType::get(rewriter.getContext(), 64); + else if (constraint == 'f') + ty = 
FloatType::getF32(rewriter.getContext()); + else if (constraint == 'd') + ty = FloatType::getF64(rewriter.getContext()); + else { + assert(false && "Unsupported constraint"); + } + return ty; +} + template class NVGPUOpPatternBase : public mlir::RewritePattern { public: explicit NVGPUOpPatternBase(mlir::MLIRContext *context) : mlir::RewritePattern(SourceOp::getOperationName(), 1, context) {} + // Converts the given value to the type represented by the constraint + // E.g. if val is of type llvmptr and constraint is 'r', then we convert + // val to i32 using ptrtoint(i32_ty, val) + mlir::Value convertToType(mlir::Value val, std::string constraint, + Location &loc, + mlir::PatternRewriter &rewriter) const { + auto isConstraintNumber = isNumber(constraint); + if (!isConstraintNumber) { + auto ty = getTypeFromConstraint(constraint[0], rewriter); + if (val.getType().isa()) { + return ptrtoint(ty, val); + } else { + assert(val.getType().getIntOrFloatBitWidth() <= + ty.getIntOrFloatBitWidth() && + "Cannot convert to a smaller type"); + return zext(ty, val); + } + } + return val; + } + + SmallVector + getPtxOutputs(std::vector &outputConstraints, + PTXBuilder &ptxBuilder) const { + SmallVector ptxOutputs; + for (unsigned i = 0; i < outputConstraints.size(); i++) { + auto *ptxOutput = ptxBuilder.newOperand(outputConstraints[i]); + ptxOutputs.push_back(ptxOutput); + } + return ptxOutputs; + } + + OperandsAndConstraints + unpackOperands(OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location &loc, + mlir::PatternRewriter &rewriter) const { + OperandsAndConstraints unpackedOperands; + for (auto &[operand, constraint] : operandsAndConstraints) { + auto llvmStruct = llvm::dyn_cast(operand.getType()); + // if a constraint is a number, then we are doing input/output tying + // if the operand is a struct, then we need to unpack it, and + // add the constraint to each of the unpacked operands uses the constraint + // as an offset + auto isConstraintNumber = isNumber(constraint); + if (llvmStruct) { + for (unsigned i = 0; i < llvmStruct.getBody().size(); i++) { + if (isConstraintNumber) { + auto constraintInt = std::stoi(constraint) + i; + unpackedOperands.push_back( + {extract_val(llvmStruct.getBody()[i], operand, i), + std::to_string(constraintInt)}); + } else { + unpackedOperands.push_back( + {extract_val(llvmStruct.getBody()[i], operand, i), constraint}); + } + } + } else { + unpackedOperands.push_back({operand, constraint}); + } + } + return unpackedOperands; + } + + SmallVector + getPtxOperands(OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location &loc, + mlir::PatternRewriter &rewriter) const { + SmallVector ptxOperands; + auto unpackedOperandsAndConstraints = + unpackOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + for (auto &[operand, constraint] : unpackedOperandsAndConstraints) { + auto convertedOperand = convertToType(operand, constraint, loc, rewriter); + auto *ptxOperand = ptxBuilder.newOperand(convertedOperand, constraint); + ptxOperands.push_back(ptxOperand); + } + return ptxOperands; + } + + virtual std::vector getOutputConstraints(SourceOp op) const { + return {}; + } + + virtual OperandsAndConstraints getOperandsAndConstraints(SourceOp op) const { + return {}; + } + + Type getReturnType(std::vector outputConstraints, + mlir::PatternRewriter &rewriter) const { + auto ctx = rewriter.getContext(); + Type resTy; + if (outputConstraints.empty()) { + resTy = void_ty(ctx); + } else { + SmallVector retTys; + for (auto 
&outputConstraint : outputConstraints) { + assert(outputConstraint[0] == '=' && + "Constraint must be for an output"); + Type retTy = getTypeFromConstraint(outputConstraint[1], rewriter); + retTys.push_back(retTy); + } + if (retTys.size() == 1) { + resTy = retTys[0]; + } else { + resTy = struct_ty(retTys); + } + } + return resTy; + } + + std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const { + std::vector> patchLocations; + std::vector patchValues; + auto start = ptxAsm.find("#", 0); + while (start != std::string::npos) { + auto endIterator = + std::find_if(ptxAsm.begin() + start + 1, ptxAsm.end(), + [](unsigned char c) { return !std::isalnum(c); }); + + assert(endIterator != ptxAsm.end() && "unexpected asm format"); + + auto end = std::distance(ptxAsm.begin(), endIterator); + auto patchLocation = std::make_pair(start, end); + patchLocations.push_back(patchLocation); + auto patchValue = ptxAsm.substr(start + 1, end - start - 1); + patchValues.push_back(patchValue); + start = ptxAsm.find("#", end); + } + assert(patchLocations.size() == patchValues.size() && + "patchLocations and patchValues should have the same size"); + if (patchLocations.size() == 0) { + return ptxAsm; + } + std::string res = ""; + size_t prevStart = 0; + unsigned i = 0; + for (auto &[start, end] : patchLocations) { + res += ptxAsm.substr(prevStart, start - prevStart); + auto integerAttr = op->getAttrOfType(patchValues[i]); + auto attr = integerAttr.getInt(); + res += std::to_string(attr); + prevStart = end; + i++; + } + if (prevStart < ptxAsm.size()) + res += ptxAsm.substr(prevStart, ptxAsm.size() - prevStart); + return res; + } + LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { @@ -35,30 +268,62 @@ class NVGPUOpPatternBase : public mlir::RewritePattern { auto sourceOp = llvm::dyn_cast(op); if (!sourceOp) return mlir::failure(); - auto ptxAsm = static_cast(this)->getPtxAsm(sourceOp); + auto concrete = static_cast(this); + auto ptxAsm = concrete->getPtxAsm(sourceOp); + auto ptxAsmPatched = patchPtxAsm(sourceOp, ptxAsm); auto hasSideEffects = !isMemoryEffectFree(sourceOp); + auto operandsAndConstraints = concrete->getOperandsAndConstraints(sourceOp); + auto outputConstraints = concrete->getOutputConstraints(sourceOp); + PTXBuilder ptxBuilder; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - ptxInstr({}, /*onlyAttachMLIRArgs=*/true); - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, - /*hasSideEffects*/ hasSideEffects); - rewriter.eraseOp(op); + auto ptxOutputs = getPtxOutputs(outputConstraints, ptxBuilder); + auto ptxOperands = + getPtxOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + SmallVector outputsAndOperands = ptxOutputs; + outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end()); + auto &ptxInstr = *ptxBuilder.create(ptxAsmPatched); + ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true); + auto retTy = getReturnType(outputConstraints, rewriter); + auto res = ptxBuilder.launch(rewriter, loc, retTy, + /*hasSideEffects*/ hasSideEffects); + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, res); + } + return mlir::success(); } }; -class CGABarrierSyncOpPattern - : public NVGPUOpPatternBase { +template +class NVGPUOpGenericPattern + : public NVGPUOpPatternBase> { public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::CGABarrierSyncOp op) const { - return "barrier.cluster.sync.aligned;"; + explicit 
NVGPUOpGenericPattern(mlir::MLIRContext *context, std::string ptxAsm, + std::vector outputConstraints, + std::vector inputConstraints) + : NVGPUOpPatternBase>(context), + ptxAsm(ptxAsm), outputConstraints(outputConstraints), + inputConstraints(inputConstraints) {} + + std::vector getOutputConstraints(SourceOp op) const { + return outputConstraints; + } + OperandsAndConstraints getOperandsAndConstraints(SourceOp op) const { + OperandsAndConstraints operandsAndConstraints; + for (unsigned i = 0; i < inputConstraints.size(); i++) { + operandsAndConstraints.push_back( + {op->getOperand(i), inputConstraints[i]}); + } + return operandsAndConstraints; } + std::string getPtxAsm(SourceOp op) const { return ptxAsm; } + +private: + std::string ptxAsm; + std::vector outputConstraints; + std::vector inputConstraints; }; class FenceAsyncSharedOpPattern @@ -78,437 +343,342 @@ class FenceAsyncSharedOpPattern } }; -class WGMMAFenceOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::WGMMAFenceOp op) const { - return "wgmma.fence.sync.aligned;"; - } -}; - -class WGMMACommitGroupOpPattern - : public NVGPUOpPatternBase { +class ClusterArriveOpPattern + : public NVGPUOpPatternBase { public: - using Base = - NVGPUOpPatternBase; + using Base = NVGPUOpPatternBase; using Base::Base; - std::string getPtxAsm(ttn::WGMMACommitGroupOp op) const { - return "wgmma.commit_group.sync.aligned;"; + std::string getPtxAsm(ttn::ClusterArriveOp op) const { + auto relaxed = op.getRelaxed(); + if (relaxed) + return "barrier.cluster.arrive.relaxed.aligned;"; + else + return "barrier.cluster.arrive.aligned;"; } }; -class WGMMAWaitGroupOpPattern - : public NVGPUOpPatternBase { +class StoreMatrixOpPattern + : public NVGPUOpPatternBase { public: - using Base = - NVGPUOpPatternBase; + using Base = NVGPUOpPatternBase; using Base::Base; - std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const { - auto pendings = op.getPendings(); - return "wgmma.wait_group.sync.aligned " + std::to_string(pendings) + ";"; + OperandsAndConstraints + getOperandsAndConstraints(ttn::StoreMatrixOp op) const { + OperandsAndConstraints operandsAndTypes; + auto addr = op.getAddr(); + auto datas = op.getDatas(); + operandsAndTypes.push_back({addr, "r"}); + for (unsigned i = 0; i < datas.size(); i++) { + operandsAndTypes.push_back({datas[i], "r"}); + } + return operandsAndTypes; } -}; - -class StoreMatrixOpPattern : public mlir::RewritePattern { -public: - StoreMatrixOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::StoreMatrixOp::getOperationName(), 1, - context) {} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto storeMatrixOp = llvm::dyn_cast(op); - if (!storeMatrixOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto addr = storeMatrixOp.getAddr(); - auto datas = storeMatrixOp.getDatas(); - - assert(datas.size() == 1 || datas.size() == 2 || - datas.size() == 4 && "Invalid size for StoreMatrixOp"); - PTXBuilder ptxBuilder; - auto &ptxInstr = *ptxBuilder.create( - "stmatrix.sync.aligned.m8n8.x" + std::to_string(datas.size()) + - ".shared.b16"); - auto *addrOpr = ptxBuilder.newAddrOperand(ptrtoint(i32_ty, addr), "r"); - - SmallVector> args; - for (unsigned i = 0; i < datas.size(); ++i) { - args.push_back({datas[i], "r"}); + std::string getPtxAsm(ttn::StoreMatrixOp op) const { + auto datas = op.getDatas(); + std::string ptxAsm; + switch (datas.size()) 
{ + case 1: + ptxAsm = "stmatrix.sync.aligned.m8n8.x1.shared.b16 [$0], {$1};"; + break; + case 2: + ptxAsm = "stmatrix.sync.aligned.m8n8.x2.shared.b16 [$0], {$1, $2};"; + break; + case 4: + ptxAsm = + "stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};"; + break; + default: + assert(false && "Invalid size"); } - auto *operands = ptxBuilder.newListOperand(args); - - ptxInstr(addrOpr, operands); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); + return ptxAsm; } }; -class MBarrierInitOpPattern : public mlir::RewritePattern { +class MBarrierArriveOpPattern + : public NVGPUOpPatternBase { public: - MBarrierInitOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::MBarrierInitOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto mBarrierInitOp = llvm::dyn_cast(op); - if (!mBarrierInitOp) - return mlir::failure(); - auto loc = op->getLoc(); - Value mbarrier = mBarrierInitOp.getMbarrier(); - Value pred = mBarrierInitOp.getPred(); - uint32_t count = mBarrierInitOp.getCount(); - PTXBuilder ptxBuilder; - - auto &ptxInstr = *ptxBuilder.create("mbarrier.init.shared.b64"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - auto *expectedOpr = ptxBuilder.newConstantOperand(count); - - ptxInstr(barOpr, expectedOpr).predicate(pred, "b"); + using Base = + NVGPUOpPatternBase; + using Base::Base; - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); + OperandsAndConstraints + getOperandsAndConstraints(ttn::MBarrierArriveOp op) const { + OperandsAndConstraints operandsAndTypes; + Value mbarrier = op.getMbarrier(); + Value pred = op.getPred(); + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + + switch (arriveType) { + case ttn::MBarriveType::normal: + case ttn::MBarriveType::cp_async: + case ttn::MBarriveType::expect_tx: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + case ttn::MBarriveType::remote: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return operandsAndTypes; } -}; - -class MBarrierArriveOpPattern : public mlir::RewritePattern { -public: - MBarrierArriveOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::MBarrierArriveOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto mbarrierArriveOp = llvm::dyn_cast(op); - if (!mbarrierArriveOp) - return mlir::failure(); - auto loc = op->getLoc(); - Value mbarrier = mbarrierArriveOp.getMbarrier(); - Value pred = mbarrierArriveOp.getPred(); - Value ctaId = mbarrierArriveOp.getCtaId(); - auto arriveType = mbarrierArriveOp.getArriveType(); - uint32_t txCount = mbarrierArriveOp.getTxCount(); - PTXBuilder ptxBuilder; - if (arriveType == ttn::MBarriveType::normal) { - auto &ptxInstr = - *ptxBuilder.create("mbarrier.arrive.shared.b64 _,"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - - 
ptxInstr(barOpr).predicate(pred, "b"); - } else if (arriveType == ttn::MBarriveType::cp_async) { - auto &ptxInstr = *ptxBuilder.create( - "cp.async.mbarrier.arrive.noinc.shared.b64"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - - ptxInstr(barOpr).predicate(pred, "b"); - } else if (arriveType == ttn::MBarriveType::expect_tx) { + std::string getPtxAsm(ttn::MBarrierArriveOp op) const { + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + uint32_t txCount = op.getTxCount(); + std::string ptxAsm; + switch (arriveType) { + case ttn::MBarriveType::normal: + ptxAsm = "@$1 mbarrier.arrive.shared.b64 _, [$0];"; + break; + case ttn::MBarriveType::cp_async: + ptxAsm = "@$1 cp.async.mbarrier.arrive.noinc.shared.b64 [$0];"; + break; + case ttn::MBarriveType::expect_tx: assert(txCount > 0 && "txCount should be valid"); - auto &ptxInstr = *ptxBuilder.create( - "mbarrier.arrive.expect_tx.shared.b64 _,"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - auto *expectedOpr = ptxBuilder.newConstantOperand(txCount); - - ptxInstr(barOpr, expectedOpr).predicate(pred, "b"); - } else if (arriveType == ttn::MBarriveType::remote) { + ptxAsm = "@$1 mbarrier.arrive.expect_tx.shared.b64 _, [$0], " + + std::to_string(txCount) + ";"; + break; + case ttn::MBarriveType::remote: assert(ctaId && "ctaId should have a valid value"); - auto ptxAsm = + ptxAsm = " { .reg .b32 remAddr32; \n" " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - - ptxInstr({barOpr, ctaIdOpr, predOpr}, /*onlyAttachMLIRArgs=*/true); - } else { - assert(false && - "Unsupported mbarrier arrive type"); // TODO: is this the right way - // to assert in LLVM pass ? 
+ break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; } - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; -class MBarrierWaitOpPattern : public mlir::RewritePattern { -public: - MBarrierWaitOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::MBarrierWaitOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto mBarrierWaitOp = llvm::dyn_cast(op); - if (!mBarrierWaitOp) - return mlir::failure(); - auto loc = op->getLoc(); - Value mbarrier = mBarrierWaitOp.getMbarrier(); - Value phase = mBarrierWaitOp.getPhase(); - PTXBuilder ptxBuilder; - - auto ptxAsm = - "{\n" - ".reg .pred P1; \n" - "LAB_WAIT: \n" - "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, 0x989680; \n" - "@P1 bra.uni DONE; \n" - "bra.uni LAB_WAIT; \n" - "DONE: \n" - "}"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, mbarrier), "r"); - auto *phaseOpr = ptxBuilder.newOperand(zext(i32_ty, phase), "r"); - - ptxInstr({barOpr, phaseOpr}, - /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); + return ptxAsm; } }; -class ClusterArriveOpPattern - : public NVGPUOpPatternBase { +class TMALoadTiledOpPattern + : public NVGPUOpPatternBase { public: - using Base = NVGPUOpPatternBase; + using Base = NVGPUOpPatternBase; using Base::Base; - std::string getPtxAsm(ttn::ClusterArriveOp op) const { - auto relaxed = op.getRelaxed(); - if (relaxed) - return "barrier.cluster.arrive.relaxed.aligned;"; - else - return "barrier.cluster.arrive.aligned;"; - } -}; + OperandsAndConstraints + getOperandsAndConstraints(ttn::TMALoadTiledOp op) const { + OperandsAndConstraints operandsAndTypes; + auto dst = op.getDst(); + auto mbarrier = op.getMbarrier(); + auto tmaDesc = op.getTmaDesc(); + auto l2Desc = op.getL2Desc(); + auto pred = op.getPred(); + auto coords = op.getCoords(); + auto mcastMask = op.getMcastMask(); -class ClusterWaitOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - std::string getPtxAsm(ttn::ClusterWaitOp op) const { - return "barrier.cluster.wait.aligned;"; - } -}; + auto dimSize = coords.size(); + assert(dimSize == 2 || (dimSize == 4 && mcastMask == nullptr) && + "Does not support TMA configuration"); -class TMALoadTiledOpPattern : public mlir::RewritePattern { -public: - TMALoadTiledOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::TMALoadTiledOp::getOperationName(), 1, - context) {} + operandsAndTypes.push_back({dst, "r"}); + operandsAndTypes.push_back({tmaDesc, "l"}); + for (unsigned i = 0; i < coords.size(); i++) { + operandsAndTypes.push_back({coords[i], "r"}); + } + operandsAndTypes.push_back({mbarrier, "l"}); + if (mcastMask) { + operandsAndTypes.push_back({mcastMask, "h"}); + } + operandsAndTypes.push_back({l2Desc, "l"}); + operandsAndTypes.push_back({pred, "b"}); - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto tmaLoadTiledOp = llvm::dyn_cast(op); - if (!tmaLoadTiledOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto dst = tmaLoadTiledOp.getDst(); - auto 
mbarrier = tmaLoadTiledOp.getMbarrier(); - auto tmaDesc = tmaLoadTiledOp.getTmaDesc(); - auto l2Desc = tmaLoadTiledOp.getL2Desc(); - auto pred = tmaLoadTiledOp.getPred(); - auto coords = tmaLoadTiledOp.getCoords(); - auto mcastMask = tmaLoadTiledOp.getMcastMask(); + return operandsAndTypes; + } + std::string getPtxAsm(ttn::TMALoadTiledOp op) const { + auto coords = op.getCoords(); + auto mcastMask = op.getMcastMask(); auto dimSize = coords.size(); - - PTXBuilder ptxBuilder; + std::string ptxAsm; if (dimSize == 2) { if (mcastMask == nullptr) { - auto ptxAsm = - "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier:" - ":complete_tx" - "::bytes.L2::cache_hint [$0], [$1, {$2, $3}], [$4], $5;"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); - auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - - ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, barOpr, l2DescOpr, predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier:" + ":complete_tx" + "::bytes.L2::cache_hint [$0], [$1, {$2, $3}], [$4], $5;"; } else { - auto ptxAsm = - "@$7 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" - "complete_tx::bytes.multicast::cluster.L2::cache_hint" - " [$0], [$1, {$2, $3}], [$4], $5, $6;"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); - auto *maskOpr = ptxBuilder.newOperand(mcastMask, "h"); - auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, barOpr, maskOpr, l2DescOpr, - predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$7 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [$0], [$1, {$2, $3}], [$4], $5, $6;"; } } else if (dimSize == 4) { assert(mcastMask == nullptr && "Does not support multicast"); - auto ptxAsm = "@$8 " - "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier:" - ":complete_tx" - "::bytes.L2::cache_hint [$0], [$1, {$2, $3, $4, $5}], " - "[$6], $7;"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); - auto *c3Opr = ptxBuilder.newOperand(coords[3], "r"); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); - auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, c2Opr, c3Opr, barOpr, l2DescOpr, - predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$8 " + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier:" + ":complete_tx" + "::bytes.L2::cache_hint 
[$0], [$1, {$2, $3, $4, $5}], " + "[$6], $7;"; } else { - assert(false && "invalid dim size"); + llvm::errs() << "Unsupported dimSize " << dimSize << "\n"; + llvm_unreachable(""); } - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); - rewriter.eraseOp(op); - return mlir::success(); + return ptxAsm; } }; -class TMAStoreTiledOpPattern : public mlir::RewritePattern { +class TMAStoreTiledOpPattern + : public NVGPUOpPatternBase { public: - TMAStoreTiledOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::TMAStoreTiledOp::getOperationName(), 1, - context) {} + using Base = NVGPUOpPatternBase; + using Base::Base; - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto tmaStoreTiledOp = llvm::dyn_cast(op); - if (!tmaStoreTiledOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto src = tmaStoreTiledOp.getSrc(); - auto tmaDesc = tmaStoreTiledOp.getTmaDesc(); - auto pred = tmaStoreTiledOp.getPred(); - auto coords = tmaStoreTiledOp.getCoords(); + OperandsAndConstraints + getOperandsAndConstraints(ttn::TMAStoreTiledOp op) const { + OperandsAndConstraints operandsAndTypes; + auto src = op.getSrc(); + auto tmaDesc = op.getTmaDesc(); + auto pred = op.getPred(); + auto coords = op.getCoords(); auto dimSize = coords.size(); + if (dimSize != 2 && dimSize != 3 && dimSize != 4) { + llvm::errs() << "Unsupported dimSize " << dimSize << "\n"; + llvm_unreachable(""); + } + operandsAndTypes.push_back({tmaDesc, "l"}); + operandsAndTypes.push_back({src, "r"}); + for (unsigned i = 0; i < dimSize; i++) { + operandsAndTypes.push_back({coords[i], "r"}); + } + operandsAndTypes.push_back({pred, "b"}); - PTXBuilder ptxBuilder; + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::TMAStoreTiledOp op) const { + auto coords = op.getCoords(); + auto dimSize = coords.size(); + std::string ptxAsm; if (dimSize == 2) { - auto ptxAsm = "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" - "[$0, {$2, $3}], [$1];"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + "[$0, {$2, $3}], [$1];"; } else if (dimSize == 3) { - auto ptxAsm = "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group" - "[$0, {$2, $3, $4}], [$1];"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, c2Opr, predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group" + "[$0, {$2, $3, $4}], [$1];"; } else if (dimSize == 4) { - auto ptxAsm = "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" - "[$0, {$2, $3, $4, $5}], [$1];"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - 
auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); - auto *c3Opr = ptxBuilder.newOperand(coords[3], "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, c2Opr, c3Opr, predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + "[$0, {$2, $3, $4, $5}], [$1];"; } else { - assert(false && "invalid dim size"); + llvm::errs() << "Unsupported dimSize " << dimSize << "\n"; + llvm_unreachable(""); } + return ptxAsm; + } +}; - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); - rewriter.eraseOp(op); - return mlir::success(); +class StoreDSmemOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + OperandsAndConstraints getOperandsAndConstraints(ttn::StoreDSmemOp op) const { + OperandsAndConstraints operandsAndTypes; + auto addr = op.getAddr(); + auto ctaId = op.getCtaId(); + auto values = op.getValues(); + auto pred = op.getPred(); + auto bitwidth = op.getBitwidth(); + operandsAndTypes.push_back({addr, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? "r" : "l"); + for (unsigned i = 0; i < values.size(); i++) { + operandsAndTypes.push_back({values[i], c}); + } + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::StoreDSmemOp op) const { + auto bitwidth = op.getBitwidth(); + auto vec = op.getVec(); + auto values = op.getValues(); + assert( + (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && + "invalid bitwidth"); + assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && + "invalid vec size"); + std::string ptxAsm; + if (vec == 1) { + ptxAsm = "{ \n" + ".reg .u32 remoteAddr; \n" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" + ".reg .pred p; \n" + "mov.pred p, $2; \n" + "@p st.shared::cluster.u#bitwidth [remoteAddr], $3; \n" + "}\n"; + } + if (vec == 2) { + ptxAsm = "{ \n" + ".reg .u32 remoteAddr; \n" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" + ".reg .pred p; \n" + "mov.pred p, $2; \n" + "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4}; \n" + "}\n"; + } + if (vec == 4) { + ptxAsm = "{ \n" + ".reg .u32 remoteAddr; \n" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" + ".reg .pred p; \n" + "mov.pred p, $2; \n" + "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4, $5, " + "$6}; \n" + "}\n"; + } + return ptxAsm; } }; -class LoadDSmemOpPattern : public mlir::RewritePattern { +class LoadDSmemOpPattern + : public NVGPUOpPatternBase { public: - LoadDSmemOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::LoadDSmemOp::getOperationName(), 1, context) { + using Base = NVGPUOpPatternBase; + using Base::Base; + + std::vector getOutputConstraints(ttn::LoadDSmemOp op) const { + auto bitwidth = op.getBitwidth(); + std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? 
"=r" : "=l"); + auto vec = op.getVec(); + return std::vector(vec, c); + } + OperandsAndConstraints getOperandsAndConstraints(ttn::LoadDSmemOp op) const { + OperandsAndConstraints operandsAndTypes; + auto addr = op.getAddr(); + auto ctaId = op.getCtaId(); + + operandsAndTypes.push_back({addr, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + return operandsAndTypes; } - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto loadDSmemOp = llvm::dyn_cast(op); - if (!loadDSmemOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto addr = loadDSmemOp.getAddr(); - auto ctaId = loadDSmemOp.getCtaId(); - auto bitwidth = loadDSmemOp.getBitwidth(); - auto vec = loadDSmemOp.getVec(); + std::string getPtxAsm(ttn::LoadDSmemOp op) const { + auto addr = op.getAddr(); + auto ctaId = op.getCtaId(); + auto bitwidth = op.getBitwidth(); + auto vec = op.getVec(); assert( (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && "invalid bitwidth"); assert((vec == 1 || vec == 2 || vec == 4) && "invalid vec size"); - PTXBuilder ptxBuilder; std::string o1 = vec > 1 ? ".v.u" : ".u"; std::string vecStr = vec == 1 ? "$0" @@ -524,58 +694,64 @@ class LoadDSmemOpPattern : public mlir::RewritePattern { o1 + std::to_string(bitwidth) + " " + vecStr + ", [remoteAddr];\n" "}\n"; + return ptxAsm; + } +}; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? "=r" : "=l"); - SmallVector oprs; - for (unsigned i = 0; i < vec; ++i) { - auto *ret = ptxBuilder.newOperand(c); - oprs.push_back(ret); - } - auto *addrOpr = ptxBuilder.newOperand(addr, "r"); - auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); - oprs.push_back(addrOpr); - oprs.push_back(ctaIdOpr); - - Type retTy = IntegerType::get(rewriter.getContext(), bitwidth); - SmallVector retTys(vec, retTy); - if (vec > 1) - retTy = struct_ty(retTys); +class WGMMAOpPattern : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; - ptxInstr(oprs, - /*onlyAttachMLIRArgs=*/true); + std::vector getOutputConstraints(ttn::WGMMAOp op) const { + // TODO (zahi): Return type must always be a struct for wgmma, currently + // we rely on the size of output constraints vector to determine whether + // the output is a struct or not. We should find a way to pass this info + auto opC = op.getOpC(); + auto typeC = opC.getType(); - auto res = ptxBuilder.launch(rewriter, loc, retTy); - rewriter.replaceOp(op, {res}); - return mlir::success(); + auto structTypeC = typeC.dyn_cast(); + uint32_t numCRegs = structTypeC.getBody().size(); + std::string c = structTypeC.getBody().front().isF32() ? "=f" : "=r"; + return std::vector(numCRegs, c); } -}; -class WGMMAOpPattern : public mlir::RewritePattern { -public: - WGMMAOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::WGMMAOp::getOperationName(), 1, context) {} + OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const { + OperandsAndConstraints operandsAndConstraints; + auto opA = op.getOpA(); + auto opB = op.getOpB(); + auto opC = op.getOpC(); + auto typeA = opA.getType(); - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { + auto structTypeA = typeA.dyn_cast(); + + // TODO (zahi): is this the best way to tie inputs/outputs ? 
+ operandsAndConstraints.push_back({opC, "0"}); + + if (structTypeA) { + operandsAndConstraints.push_back({opA, "f"}); + } else { + operandsAndConstraints.push_back({opA, "l"}); + } + + // Operand B (must be `desc`) + operandsAndConstraints.push_back({opB, "l"}); + return operandsAndConstraints; + } + + std::string getPtxAsm(ttn::WGMMAOp op) const { using namespace ttn; - auto ctx = rewriter.getContext(); - auto wgmmaOp = llvm::dyn_cast(op); - if (!wgmmaOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto opA = wgmmaOp.getOpA(); - auto opB = wgmmaOp.getOpB(); - auto opC = wgmmaOp.getOpC(); - auto m = wgmmaOp.getM(); - auto n = wgmmaOp.getN(); - auto k = wgmmaOp.getK(); - auto eltTypeC = wgmmaOp.getEltTypeC(); - auto eltTypeA = wgmmaOp.getEltTypeA(); - auto eltTypeB = wgmmaOp.getEltTypeB(); - auto layoutA = wgmmaOp.getLayoutA(); - auto layoutB = wgmmaOp.getLayoutB(); + auto opA = op.getOpA(); + auto opB = op.getOpB(); + auto opC = op.getOpC(); + auto m = op.getM(); + auto n = op.getN(); + auto k = op.getK(); + auto eltTypeC = op.getEltTypeC(); + auto eltTypeA = op.getEltTypeA(); + auto eltTypeB = op.getEltTypeB(); + auto layoutA = op.getLayoutA(); + auto layoutB = op.getLayoutB(); // Register checks auto typeA = opA.getType(); @@ -624,8 +800,6 @@ class WGMMAOpPattern : public mlir::RewritePattern { (m == 64 && 8 <= n && n <= 224 && k == 32); } assert(supported && "WGMMA type or shape is not supported"); - PTXBuilder ptxBuilder; - SmallVector oprs; // Operands uint32_t asmOpIdx = 0; @@ -637,25 +811,9 @@ class WGMMAOpPattern : public mlir::RewritePattern { args += "{"; for (uint32_t i = 0; i < numCRegs; ++i) { args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); - // LLVM does not support `+` semantic, we must repeat the arguments for - // both input and outputs - PTXBuilder::Operand *opr; - if (structTypeC.getBody().front().isF32()) - opr = ptxBuilder.newOperand( - extract_val(structTypeC.getBody()[i], opC, i), "=f"); - else - opr = ptxBuilder.newOperand( - extract_val(structTypeC.getBody()[i], opC, i), "=r"); - oprs.push_back(opr); } args += "}, "; - for (uint32_t i = asmOpIdx - numCRegs; i < asmOpIdx; ++i) { - auto *opr = ptxBuilder.newOperand(i); - oprs.push_back(opr); - } - - // Note that LLVM will not skip the indexed repeating placeholders asmOpIdx += numCRegs; // Operand A if (structTypeA) { @@ -665,21 +823,14 @@ class WGMMAOpPattern : public mlir::RewritePattern { for (uint32_t i = 0; i < numARegs; ++i) { args += "$" + std::to_string(asmOpIdx++) + (i == numARegs - 1 ? "" : ","); - auto *opr = ptxBuilder.newOperand( - extract_val(structTypeA.getBody()[i], opA, i), "f"); - oprs.push_back(opr); } args += "}, "; } else { args += "$" + std::to_string(asmOpIdx++) + ", "; - auto *opr = ptxBuilder.newOperand(opA, "l"); - oprs.push_back(opr); } // Operand B (must be `desc`) args += "$" + std::to_string(asmOpIdx++) + ", "; - auto *opr = ptxBuilder.newOperand(opB, "l"); - oprs.push_back(opr); // `scale-d` is 1 by default args += "1"; @@ -699,338 +850,37 @@ class WGMMAOpPattern : public mlir::RewritePattern { std::to_string(k) + "." + stringifyEnum(eltTypeC).str() + "." + stringifyEnum(eltTypeA).str() + "." 
+ stringifyEnum(eltTypeB).str() + " " + args + ";"; - - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - ptxInstr(oprs, - /*onlyAttachMLIRArgs=*/true); - - auto res = - ptxBuilder.launch(rewriter, loc, structTypeC, /*hasSideEffect*/ true); - rewriter.replaceOp(op, {res}); - return mlir::success(); - } -}; - -class FenceMBarrierInitOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::FenceMBarrierInitOp op) const { - return "fence.mbarrier_init.release.cluster;"; - } -}; - -class NamedBarrierArriveOpPattern : public mlir::RewritePattern { -public: - NamedBarrierArriveOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::NamedBarrierArriveOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto namedBarrierArriveOp = llvm::dyn_cast(op); - if (!namedBarrierArriveOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto bar = namedBarrierArriveOp.getBar(); - auto numThreads = namedBarrierArriveOp.getNumThreads(); - PTXBuilder ptxBuilder; - - auto &ptxInstr = *ptxBuilder.create("bar.arrive $0, $1;"); - auto *barOpr = ptxBuilder.newOperand(bar, "r"); - auto *numThreadsOpr = ptxBuilder.newOperand(numThreads, "r"); - ptxInstr({barOpr, numThreadsOpr}, /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class NamedBarrierWaitOpPattern : public mlir::RewritePattern { -public: - NamedBarrierWaitOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::NamedBarrierWaitOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto namedBarrierWaitOp = llvm::dyn_cast(op); - if (!namedBarrierWaitOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto bar = namedBarrierWaitOp.getBar(); - auto numThreads = namedBarrierWaitOp.getNumThreads(); - PTXBuilder ptxBuilder; - - auto &ptxInstr = *ptxBuilder.create("bar.sync $0, $1;"); - auto *barOpr = ptxBuilder.newOperand(bar, "r"); - auto *numThreadsOpr = ptxBuilder.newOperand(numThreads, "r"); - ptxInstr({barOpr, numThreadsOpr}, /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class CGABarrierArriveOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - std::string getPtxAsm(ttn::CGABarrierArriveOp op) const { - return "barrier.cluster.arrive;"; - } -}; - -class CGABarrierWaitOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - std::string getPtxAsm(ttn::CGABarrierWaitOp op) const { - return "barrier.cluster.wait;"; - } -}; - -class StoreDSmemOpPattern : public mlir::RewritePattern { -public: - StoreDSmemOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::StoreDSmemOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto storeDSmemOp = llvm::dyn_cast(op); - if (!storeDSmemOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto addr = 
storeDSmemOp.getAddr(); - auto ctaId = storeDSmemOp.getCtaId(); - auto values = storeDSmemOp.getValues(); - auto pred = storeDSmemOp.getPred(); - - auto bitwidth = storeDSmemOp.getBitwidth(); - auto vec = storeDSmemOp.getVec(); - assert( - (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && - "invalid bitwidth"); - assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && - "invalid vec size"); - - PTXBuilder ptxBuilder; - - std::string ptxAsm = "{\n\t" - ".reg .u32 remoteAddr;\n\t" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n\t" - ".reg .pred p;\n\t" - "mov.pred p, $2;\n\t" - "@p st.shared::cluster"; - if (vec > 1) - ptxAsm += ".v" + std::to_string(vec); - ptxAsm += ".u" + std::to_string(bitwidth) + " [remoteAddr], "; - if (vec == 1) - ptxAsm += "$3"; - else if (vec == 2) - ptxAsm += "{$3, $4}"; - else if (vec == 4) - ptxAsm += "{$3, $4, $5, $6}"; - ptxAsm += ";\n\t"; - ptxAsm += "}\n"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? "r" : "l"); - SmallVector oprs; - auto *addrOpr = ptxBuilder.newOperand(addr, "r"); - oprs.push_back(addrOpr); - auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); - oprs.push_back(ctaIdOpr); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - oprs.push_back(predOpr); - for (unsigned i = 0; i < values.size(); i++) { - auto *valueOpr = ptxBuilder.newOperand(values[i], c); - oprs.push_back(valueOpr); - } - ptxInstr(oprs, - /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class Sts64OpPattern : public mlir::RewritePattern { -public: - Sts64OpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::Sts64Op::getOperationName(), 1, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto sts64Op = llvm::dyn_cast(op); - if (!sts64Op) - return mlir::failure(); - auto loc = op->getLoc(); - auto offset = sts64Op.getOffset(); - auto d0 = sts64Op.getD0(); - auto d1 = sts64Op.getD1(); - - PTXBuilder ptxBuilder; - - std::string ptxAsm = "st.shared.v2.b32 [$0], {$1, $2}"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - SmallVector oprs; - auto *addrOpr = ptxBuilder.newOperand(offset, "r"); - auto *d0Opr = ptxBuilder.newOperand(d0, "r"); - auto *d1Opr = ptxBuilder.newOperand(d1, "r"); - - ptxInstr({addrOpr, d0Opr, d1Opr}, - /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); + return ptxAsm; } }; -class RegAllocOpPattern - : public NVGPUOpPatternBase { +class OffsetOfSts64OpPattern + : public NVGPUOpPatternBase { public: - using Base = NVGPUOpPatternBase; + using Base = NVGPUOpPatternBase; using Base::Base; - std::string getPtxAsm(ttn::RegAllocOp op) const { - auto regCount = op.getRegCount(); - return "setmaxnreg.inc.sync.aligned.u32 " + std::to_string(regCount) + ";"; + std::vector getOutputConstraints(ttn::OffsetOfSts64Op op) const { + return {"=r"}; } -}; - -class RegDeallocOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::RegDeallocOp op) const { - auto regCount = op.getRegCount(); - return "setmaxnreg.dec.sync.aligned.u32 " + std::to_string(regCount) + ";"; - } -}; - -class ClusterCTAIdOpPattern : 
public mlir::RewritePattern { -public: - ClusterCTAIdOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::ClusterCTAIdOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto clusterCTAIdOp = llvm::dyn_cast(op); - if (!clusterCTAIdOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto x = getSRegValue(rewriter, loc, "%cluster_ctaid.x"); - auto y = getSRegValue(rewriter, loc, "%cluster_ctaid.y"); - auto z = getSRegValue(rewriter, loc, "%cluster_ctaid.z"); - auto nx = getSRegValue(rewriter, loc, "%cluster_nctaid.x"); - auto ny = getSRegValue(rewriter, loc, "%cluster_nctaid.y"); - auto res = add(x, mul(add(y, mul(z, ny)), nx)); - rewriter.replaceOp(op, {res}); - return mlir::success(); - } -}; + OperandsAndConstraints + getOperandsAndConstraints(ttn::OffsetOfSts64Op op) const { + OperandsAndConstraints operandsAndConstraints; + auto threadId = op.getThreadId(); + auto rowOfWarp = op.getRowOfWarp(); + auto elemIdx = op.getElemIdx(); -class WGMMADescCreateOpPattern : public mlir::RewritePattern { -public: - WGMMADescCreateOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::WGMMADescCreateOp::getOperationName(), 1, - context) {} + operandsAndConstraints.push_back({threadId, "r"}); + operandsAndConstraints.push_back({elemIdx, "r"}); + operandsAndConstraints.push_back({rowOfWarp, "r"}); - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto wgmmaDescCreateOp = llvm::dyn_cast(op); - if (!wgmmaDescCreateOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto buffer = wgmmaDescCreateOp.getBuffer(); - auto height = wgmmaDescCreateOp.getHeight(); - uint32_t mode = static_cast(wgmmaDescCreateOp.getMode()); - - auto smem_nvvm_pointer = ptrtoint(i64_ty, buffer); - - Value desc = int_val(64, 0); - uint64_t swizzling = (mode == 1 ? 128 : mode == 2 ? 
64 : 32); - Value swizzling_ = int_val(64, swizzling); - Value smem_address_bit = smem_nvvm_pointer; - - Value strideDimension = - lshr(shl(swizzling_, int_val(64, 3)), int_val(64, 4)); - Value height64 = zext(i64_ty, height); - Value leadingDimension = lshr(mul(height64, swizzling_), int_val(64, 4)); - - // Value baseOffset = int_val(64, 0); - Value startAddr = - lshr(shl(smem_address_bit, int_val(64, 46)), int_val(64, 50)); - - Value mode_ = int_val(64, mode); - desc = or_(desc, shl(mode_, int_val(64, 62))); - desc = or_(desc, shl(strideDimension, int_val(64, 32))); - desc = or_(desc, shl(leadingDimension, int_val(64, 16))); - // desc = or_(desc, shl(baseOffset, int_val(64, 49))); - desc = or_(desc, startAddr); - - rewriter.replaceOp(op, {desc}); - return mlir::success(); + return operandsAndConstraints; } -}; -class OffsetOfSts64OpPattern : public mlir::RewritePattern { -public: - OffsetOfSts64OpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::OffsetOfSts64Op::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto offsetOfSts64Op = llvm::dyn_cast(op); - if (!offsetOfSts64Op) - return mlir::failure(); - auto loc = op->getLoc(); - auto threadId = offsetOfSts64Op.getThreadId(); - auto rowOfWarp = offsetOfSts64Op.getRowOfWarp(); - auto elemIdx = offsetOfSts64Op.getElemIdx(); - auto leadingDimOffset = offsetOfSts64Op.getLeadingDimOffset(); - auto rowStride = offsetOfSts64Op.getRowStride(); - auto swizzleEnabled = offsetOfSts64Op.getSwizzleEnabled(); + std::string getPtxAsm(ttn::OffsetOfSts64Op op) const { + auto rowStride = op.getRowStride(); + auto swizzleEnabled = op.getSwizzleEnabled(); if (swizzleEnabled) { assert((rowStride == 32 || rowStride == 64 || rowStride == 128) && @@ -1048,51 +898,77 @@ class OffsetOfSts64OpPattern : public mlir::RewritePattern { } else if (rowStride == 32) { perPhase = 4; maxPhase = 2; + } else { + assert(false && "Unsupported rowStride"); } - auto laneId = and_(threadId, i32_val(0x1f)); - auto myRow = - add(mul(and_(lshr(elemIdx, i32_val(1)), i32_val(0x1)), i32_val(8)), - udiv(laneId, i32_val(4))); - auto myCol = add(mul(udiv(elemIdx, i32_val(4)), i32_val(8)), - mul(urem(laneId, i32_val(4)), i32_val(2))); - myRow = add(myRow, rowOfWarp); - auto phase = urem(udiv(myRow, i32_val(perPhase)), i32_val(maxPhase)); - auto lineOffset = - add(mul(urem(myRow, i32_val(perPhase)), i32_val(rowStride)), - mul(myCol, i32_val(4))); - auto colOffset = - add(mul(xor_(udiv(lineOffset, i32_val(16)), phase), i32_val(16)), - urem(lineOffset, i32_val(16))); - auto offset = - add(mul(udiv(myRow, i32_val(perPhase)), i32_val(128)), colOffset); - - rewriter.replaceOp(op, {offset}); - return mlir::success(); + auto ptxAsm = "{\n" + ".reg .u32 a<9>; \n" + "and.b32 a0, $1, 0x1f;\n" // laneid + "shr.b32 a1, $2, 4; \n" + "and.b32 a1, a1, 0x1; \n" + "div.u32 a2, a0, 4; \n" + "mad.lo.u32 a2, a1, 8, a2; \n" // myRow + "div.u32 a3, $2, 4; \n" + "rem.u32 a4, a0, 4; \n" + "mul.lo.u32 a4, a4, 2; \n" + "mad.lo.u32 a4, a3, 8, a4; \n" // myCol + "add.u32 a2, a2, $3; \n" // myRow = myRow + rowOfWarp + "div.u32 a3, a2, " + + std::to_string(perPhase) + + "; \n" + "rem.u32 a3, a3, " + + std::to_string(maxPhase) + + "; \n" // phase + "rem.u32 a5, a2, " + + std::to_string(perPhase) + + "; \n" // lineOffset + "mul.lo.u32 a5, a5, #rowStride; \n" + "mad.lo.u32 a5, a4, 4, a5; \n" // lineOffset + "div.u32 a6, a5, 16; \n" + "xor.b32 a6, a6, a3; \n" // 
colOffset + "rem.u32 a7, a5, 16; \n" + "mad.lo.u32 a7, a6, 16, a7; \n" // colOffset + "div.u32 a8, a2, #perPhase; \n" + "mad.lo.u32 $0, a8, 128, a7; \n" // offset + "}"; + return ptxAsm; } }; -class OffsetOfStmatrixV4OpPattern : public mlir::RewritePattern { +class OffsetOfStmatrixV4OpPattern + : public NVGPUOpPatternBase { public: - OffsetOfStmatrixV4OpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::OffsetOfStmatrixV4Op::getOperationName(), 1, - context) {} + using Base = NVGPUOpPatternBase; + using Base::Base; - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto offsetOfStmatrixV4Op = llvm::dyn_cast(op); - if (!offsetOfStmatrixV4Op) - return mlir::failure(); - auto loc = op->getLoc(); - auto threadId = offsetOfStmatrixV4Op.getThreadId(); - auto rowOfWarp = offsetOfStmatrixV4Op.getRowOfWarp(); - auto elemIdx = offsetOfStmatrixV4Op.getElemIdx(); - auto leadingDimOffset = offsetOfStmatrixV4Op.getLeadingDimOffset(); - auto rowStride = offsetOfStmatrixV4Op.getRowStride(); - auto swizzleEnabled = offsetOfStmatrixV4Op.getSwizzleEnabled(); + std::vector + getOutputConstraints(ttn::OffsetOfStmatrixV4Op op) const { + return {"=r"}; + } + OperandsAndConstraints + getOperandsAndConstraints(ttn::OffsetOfStmatrixV4Op op) const { + OperandsAndConstraints operandsAndConstraints; + auto threadId = op.getThreadId(); + auto rowOfWarp = op.getRowOfWarp(); + auto elemIdx = op.getElemIdx(); + + operandsAndConstraints.push_back({threadId, "r"}); + operandsAndConstraints.push_back({elemIdx, "r"}); + operandsAndConstraints.push_back({rowOfWarp, "r"}); + + return operandsAndConstraints; + } + + std::string getPtxAsm(ttn::OffsetOfStmatrixV4Op op) const { + auto leadingDimOffset = op.getLeadingDimOffset(); + auto rowStride = op.getRowStride(); + auto swizzleEnabled = op.getSwizzleEnabled(); + + std::string ptxAsm; if (swizzleEnabled) { uint32_t perPhase = 0; uint32_t maxPhase = 0; @@ -1105,43 +981,71 @@ class OffsetOfStmatrixV4OpPattern : public mlir::RewritePattern { } else if (rowStride == 16) { perPhase = 4; maxPhase = 2; + } else { + assert(false && "Unsupported rowStride"); } - Value iterOfCol = udiv(elemIdx, i32_val(8)); - Value myRow = add(rowOfWarp, and_(threadId, i32_val(0xf))); - Value myCol = - mul(and_(lshr(threadId, i32_val(4)), i32_val(0x1)), i32_val(8)); - myCol = add(myCol, mul(iterOfCol, i32_val(16))); - - Value offset0 = - mul(udiv(myCol, i32_val(rowStride)), i32_val(leadingDimOffset)); - myCol = urem(myCol, i32_val(rowStride)); - - Value phase = urem(udiv(myRow, i32_val(perPhase)), i32_val(maxPhase)); - - Value lineOffset = - add(mul(urem(myRow, i32_val(perPhase)), i32_val(rowStride)), myCol); - Value colOffset = - add(mul(xor_(udiv(lineOffset, i32_val(8)), phase), i32_val(8)), - urem(lineOffset, i32_val(8))); - Value offset1 = - add(mul(udiv(myRow, i32_val(perPhase)), i32_val(64)), colOffset); - - Value res = add(offset1, offset0); - - rewriter.replaceOp(op, {res}); + ptxAsm = + "{\n" + ".reg .u32 a<10>; \n" + "div.u32 a0, $2, 8; \n" // iterOfCol = udiv(elemIdx, i32_val(8)) + "and.b32 a1, $1, 0xf; \n" // myRow = and_(threadId, i32_val(0xf)) + "add.u32 a1, a1, $3; \n" // myRow = myRow + rowOfWarp + "shr.b32 a2, $1, 4; \n" // myCol = lshr(threadId, i32_val(4)) + "and.b32 a2, a2, 0x1; \n" // myCol = and_(myCol, i32_val(0x1)) + "mul.lo.u32 a2, a2, 8; \n" // myCol = mul(myCol, i32_val(8)) + "mad.lo.u32 a2, a0, 16, a2; \n" // myCol = add(myCol, + // mul(iterOfCol, i32_val(16))) + 
"div.u32 a3, a2, #rowStride; \n" // offset0 = udiv(myCol, + // i32_val(rowStride)) + "mul.lo.u32 a3, a3, #leadingDimOffset; \n" // offset0 = mul(offset0, + // i32_val(leadingDimOffset)) + "rem.u32 a2, a2, #rowStride; \n" // myCol = myCol % rowStride + "div.u32 a4, a1, " + + std::to_string(perPhase) + + "; \n" // phase = myrow // perPhase + "rem.u32 a4, a4, " + + std::to_string(maxPhase) + + "; \n" // phase = phase % maxPhase + "rem.u32 a5, a1, " + + std::to_string(perPhase) + + "; \n" // lineOffset = urem(myRow, i32_val(perPhase)) + "mad.lo.u32 a5, a5, #rowStride, a2; \n" // lineOffset = + // add(mul(lineOffset, + // rowStride), myCol) + "div.u32 a6, a5, 8; \n" // colOffset = udiv(lineOffset, i32_val(8) + "xor.b32 a6, a6, a4; \n" // colOffset = xor_(colOffset, phase) + "rem.u32 a7, a5, 8; \n" // temp = urem(lineOffset, i32_val(8) + "mad.lo.u32 a7, a6, 8, a7; \n" // colOffset = add(mul(colOffset, + // i32_val(8)), temp) + "div.u32 a8, a1, " + + std::to_string(perPhase) + + "; \n" // offset1 = udiv(myRow, i32_val(perPhase)) + "mad.lo.u32 a9, a8, 64, a7; \n" // offset1 = add(mul(offset1, + // i32_val(64)), colOffset) + "add.u32 $0, a9, a3; \n" // result = add(offset1, offset0) + "}"; } else { - Value iterOfCol = udiv(elemIdx, i32_val(4)); - Value myRow = add(rowOfWarp, and_(threadId, i32_val(0xf))); - Value myCol = - mul(and_(lshr(threadId, i32_val(4)), i32_val(0x1)), i32_val(8)); - myCol = add(myCol, mul(iterOfCol, i32_val(16))); - - Value offset = - add(mul(myRow, i32_val(rowStride)), mul(myCol, i32_val(2))); - rewriter.replaceOp(op, {offset}); + ptxAsm = "{\n" + ".reg .u64 a<5>; \n" + "div.u32 a0, $2, 4; \n" // iterOfCol = udiv(elemIdx, + // i32_val(4)) + "and.b32 a1, $1, 0xf; \n" // myRow = and_(threadId, + // i32_val(0xf)) + "add.u32 a1, a1, $3; \n" // myRow = myRow + rowOfWarp + "shr.b32 a2, $1, 4; \n" // myCol = lshr(threadId, + // i32_val(4)) + "and.b32 a2, a2, 0x1; \n" // myCol = and_(myCol, + // i32_val(0x1)) + "mul.lo.u32 a2, a2, 8; \n" // myCol = mul(myCol, + // i32_val(8)) + "mul.u32 a3, a1, #rowStride; \n" // offset = myRow * rowStride + "mad.lo.u32 $0, a2, 2, a3; \n" // result = add(mul(myCol, + // i32_val(2)), offset) + "}\n"; } - return mlir::success(); + + return ptxAsm; } }; @@ -1155,35 +1059,43 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { ModuleOp mod = getOperation(); RewritePatternSet patterns(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); +#define POPULATE_NVGPU_OP(SRC_OP, ASM) \ + patterns.add>(context, ASM, Constraints(), \ + Constraints()); + POPULATE_NVGPU_OP(ttn::RegAllocOp, Reg_Alloc_Op) + POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) + POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op) + POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) + POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, Wgmma_Wait_Group_Op) + POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op) + 
POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op) + POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op) + POPULATE_NVGPU_OP(ttn::CGABarrierWaitOp, Cga_Barrier_Wait_Op) + POPULATE_NVGPU_OP(ttn::RegDeallocOp, Reg_Dealloc_Op) +#undef POPULATE_NVGPU_OP + patterns.add>( + context, Mbarrier_Init_Op, Constraints(), Constraints({"r", "b"})); + patterns.add>( + context, Mbarrier_Wait_Op, Constraints(), Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Arrive_Op, Constraints(), + Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Wait_Op, Constraints(), Constraints({"r", "r"})); + patterns.add>( + context, Sts64_Op, Constraints(), Constraints({"r", "r", "r"})); + patterns.add>( + context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Wgmma_Desc_Create_op, Constraints({"=l"}), + Constraints({"l", "l"})); + + patterns.add(context); + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index fbccdeefcfb0..f03d09788c9d 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -115,7 +115,7 @@ class DotOpMmaV3SmemLoader { mode = getModeFromLayout(sharedLayout, widthInByte); baseDesc = rewriter.create( - loc, i64_ty, base, i32_val(shape[ord[1]]), mode); + loc, base, i32_val(shape[ord[1]]), mode); } Value smemLoad(int a, int b) { diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir index 0b72d92e1ee0..8b9705db54f2 100644 --- a/test/NVGPU/test_cga.mlir +++ b/test/NVGPU/test_cga.mlir @@ -17,14 +17,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : %ptr = llvm.mlir.null : !llvm.ptr // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.mul - // CHECK: llvm.add - // CHECK: llvm.mul - // CHECK: llvm.add %v = nvgpu.cluster_id llvm.store %v, %ptr : !llvm.ptr diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir index bb4844ab5d18..f4ae65ad04cf 100644 --- a/test/NVGPU/test_wgmma.mlir +++ b/test/NVGPU/test_wgmma.mlir @@ -5,37 +5,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 %buffer = llvm.mlir.null : !llvm.ptr %height = arith.constant 16 : i32 // CHECK: llvm.ptrtoint - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.zext - // CHECK: llvm.mul - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.or - %descA = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32}: (!llvm.ptr, i32) -> (i64) + // CHECK: llvm.inline_asm + %descA = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32, swizzling = 64 : i64}: (!llvm.ptr, i32) -> (i64) // CHECK: llvm.ptrtoint - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.zext - // CHECK: llvm.mul - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.or - %descB = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32}: (!llvm.ptr, i32) -> (i64) + // CHECK: llvm.inline_asm + %descB = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : 
i32, swizzling = 64 : i64}: (!llvm.ptr, i32) -> (i64) // CHECK-COUNT-32: llvm.extractvalue // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, 1, 1, 1, 0, 1;" From 13189bfe60e135d3e7e624f8a6d5e953951c1b5e Mon Sep 17 00:00:00 2001 From: David Berard Date: Fri, 1 Sep 2023 17:08:56 -0700 Subject: [PATCH 007/122] Coalesce pass - group values with same shape/order but different element type (#2199) **Motivation**: We have a kernel that loads multiple types of tensors - some int32 and some float16. The coalescing pass assigns `perThread = 8` for the float16 tensors and `perThread = 4` for the int32 tensors, resulting in unnecessary layout conversions that result in bad performance. Instead, we should just set `perThread = 8` for both of these loads. **Details**: One of the first steps in calculating the new encoding is to find the group of upstream/downstream tensors with the "same type", in order to find the maximal sizePerThread required in this group. This PR changes the logic so that tensors can be grouped as long as they have the same shape and same optimal ordering, even if they have different encoding or dtype. Next, the logic to compute `perThread` is updated to account for the change above; since dtype can now be different within a single "group", the `perThread` computation now considers different elemNumBits/elemNumBytes for each value in the group. --- lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 59 +++++++++++++--- test/TritonGPU/coalesce.mlir | 68 +++++++++++++++++++ 2 files changed, 117 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 5ebc88083fca..9d6cd240d0a9 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -21,11 +21,25 @@ template SmallVector argSort(const T &arr) { return ret; } +unsigned getElementBitWidth(const Value &val) { + auto valType = val.getType(); + if (valType.isa()) + valType = valType.cast().getPointeeType(); + auto tensorType = valType.cast(); + + auto typeForMem = + tensorType.getElementType().isa() + ? tensorType.getElementType().cast().getPointeeType() + : tensorType.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + typedef DenseMap> LayoutMap; struct CoalescePass : public TritonGPUCoalesceBase { Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, - Value ptr, int numWarps, int threadsPerWarp) { + Value ptr, Operation *op, int numWarps, + int threadsPerWarp) { auto refType = ptr.getType(); if (refType.isa()) refType = refType.cast().getPointeeType(); @@ -74,6 +88,18 @@ struct CoalescePass : public TritonGPUCoalesceBase { order = argSort(queryAxisInfo(ptr).getContiguity()); } + auto matchesOrder = [&refTensorType](const Value &val) { + if (val.getType() == refTensorType) { + return true; + } + + auto rttType = val.getType().dyn_cast(); + if (!rttType) { + return false; + } + return rttType.getShape() == refTensorType.getShape(); + }; + // The desired divisibility is the maximum divisibility // among all dependent pointers who have the same order as // `ptr`. 
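    // Concrete case from the commit message above: a kernel loading both f16
    // and i32/f32 tensors of the same shape and order used to get
    // sizePerThread = 8 for the 16-bit loads but only 4 for the 32-bit ones,
    // forcing extra layout conversions; grouping by shape/order instead of by
    // exact tensor type lets the whole group share the larger value.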
@@ -83,7 +109,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { if (refType.isa() && ptr.getDefiningOp()) { for (Operation *op : mlir::multiRootGetSlice(ptr.getDefiningOp())) { for (Value val : op->getResults()) { - if (val.getType() != refTensorType) + if (!matchesOrder(val)) continue; auto currOrder = argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); @@ -109,11 +135,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { // Thread tile size depends on memory alignment SmallVector sizePerThread(refTensorType.getRank(), 1); - unsigned elemNumBits = typeForMem.getIntOrFloatBitWidth(); - unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned perThread = 1; for (Value val : withSameOrder) { auto valInfo = queryAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(val); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); unsigned maxContig = @@ -122,7 +148,20 @@ struct CoalescePass : public TritonGPUCoalesceBase { unsigned currPerThread = std::min(alignment, 128 / elemNumBits); perThread = std::max(perThread, currPerThread); } - sizePerThread[order[0]] = std::min(perThread, numElemsPerThread); + + perThread = std::min(perThread, numElemsPerThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + unsigned elemNumBits = getElementBitWidth(ptr); + perThread = std::min(perThread, 128 / elemNumBits); + } + sizePerThread[order[0]] = perThread; auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); return triton::gpu::BlockedEncodingAttr::get( @@ -132,9 +171,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { std::function getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr, - int numWarps, int threadsPerWarp) { - Attribute encoding = - getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps, threadsPerWarp); + Operation *op, int numWarps, int threadsPerWarp) { + Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, op, + numWarps, threadsPerWarp); return [encoding](Type type) { RankedTensorType tensorType = type.cast(); return RankedTensorType::get(tensorType.getShape(), @@ -240,8 +279,8 @@ struct CoalescePass : public TritonGPUCoalesceBase { int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - auto convertType = - getTypeConverter(axisInfoAnalysis, ptr, numWarps, threadsPerWarp); + auto convertType = getTypeConverter(axisInfoAnalysis, ptr, curr, numWarps, + threadsPerWarp); layoutMap[ptr] = convertType; }); diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 56c325683f55..f2b7f225d763 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -69,3 +69,71 @@ tt.func @load_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, 
"triton_gpu.threads-per-warp" = 32 : i32} { + + +// CHECK: [[NARROW_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK: [[WIDE_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> + %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: tt.store {{.*}} : tensor<1024xf32, [[WIDE_LAYOUT]]> + tt.store %16, %14, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked> + tt.return +} + +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-NOT: sizePerThread = [4] +// CHECK: #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-NOT: sizePerThread = [4] +tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> + %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, 
tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %17 = arith.truncf %14 : tensor<1024xf32, #blocked> to tensor<1024xf16, #blocked> + tt.store %16, %17, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked> + tt.return +} + +} From a5398368768aafa4938d29ac39d58ec8a965b11a Mon Sep 17 00:00:00 2001 From: ivanyinwz <139336681+ivanyinwz@users.noreply.github.com> Date: Mon, 4 Sep 2023 10:27:00 +0800 Subject: [PATCH 008/122] Fix predicate for store tiled op (#2215) The predicate of Store Tiled op was not set which caused a lot of perf drop due to duplicated memory traffic in epilogue. --- lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 2 +- lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 366750780673..69545a00d83e 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -566,7 +566,7 @@ class TMAStoreTiledOpPattern auto dimSize = coords.size(); std::string ptxAsm; if (dimSize == 2) { - ptxAsm = "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + ptxAsm = "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" "[$0, {$2, $3}], [$1];"; } else if (dimSize == 3) { ptxAsm = "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group" diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 4180fee6d7d6..c050c3ad3a38 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -749,7 +749,7 @@ struct StoreAsyncOpConversion typeConverter->convertType(rewriter.getI8Type()), 3); auto threadId = getThreadId(rewriter, loc); - Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0)); + Value pred = int_val(1, 1); auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter, dst.getType()); From 9e9fbe01f0fdb910c9873027dfb6c369994ac086 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 4 Sep 2023 02:57:08 -0400 Subject: [PATCH 009/122] [FRONTEND] Fix specialization on triton integer types (#2236) https://github.com/openai/triton/issues/2231 --- python/test/unit/runtime/test_cache.py | 24 ++++++++++++++++++++---- python/triton/runtime/jit.py | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 6f9b94d907f2..f75fa7c32800 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -105,6 +105,21 @@ def inc_counter(*args, **kwargs): assert counter == target +def test_annotation(): + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + + device = 
torch.cuda.current_device() + kernel[(1,)](x, 1) + kernel[(1,)](x, 8) + kernel[(1,)](x, 16) + kernel[(1,)](x, 17) + assert len(kernel.cache[device]) == 4 + + def test_constexpr_not_callable() -> None: @triton.jit def kernel(X, c: tl.constexpr): @@ -138,13 +153,14 @@ def kernel_add(a, b, o, N: tl.constexpr): torch.randn(32, dtype=torch.float32, device="cuda"), 32, ] - assert len(kernel_add.cache) == 0 + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) - assert len(kernel_add.cache) == 1 + assert len(kernel_add.cache[device]) == 1 kernel_add.warmup(*args, grid=(1,)) - assert len(kernel_add.cache) == 1 + assert len(kernel_add.cache[device]) == 1 kernel_add.warmup(*args, grid=(1,)) - assert len(kernel_add.cache) == 1 + assert len(kernel_add.cache[device]) == 1 def test_jit_debug() -> None: diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 062df8d993d1..d306c160e8ee 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -306,7 +306,7 @@ def _get_arg_specialization_key(self, arg) -> str: else (False,)' elif 'Tensor' in arg_annotation: return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)' - elif arg_annotation == 'int': + elif 'int' in arg_annotation or 'bool' in arg_annotation: return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)' else: return '(False,)' From 99f8f912aa1fb51c9cf37fbd17548fb2eebc2d22 Mon Sep 17 00:00:00 2001 From: jon-chuang <9093549+jon-chuang@users.noreply.github.com> Date: Tue, 5 Sep 2023 12:30:54 +0800 Subject: [PATCH 010/122] [OPS] Remove unnecessary perf bug workaround (#2240) This bug previously existed and I verified it in previously nightly release of triton (20230714). However, according to new benchmarks, this bug no longer exists on Triton main. 
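For reference, the inner attention loop after this change (taken verbatim from the flash_attention.py hunk below) rescales the accumulator directly:

    alpha = tl.math.exp2(m_i - m_i_new)
    p = tl.math.exp2(qk - m_i_new[:, None])
    acc *= alpha[:, None]
    acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)

with the old `acc_scale = l_i * 0 + alpha` indirection removed; the 06-fused-attention tutorial gets the same simplification.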
See: https://github.com/google/jax/pull/17328#issuecomment-1705010065 --- python/triton/ops/flash_attention.py | 3 +-- python/tutorials/06-fused-attention.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 53481fb21bd2..1ae37c297f81 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -86,8 +86,7 @@ def _fwd_kernel( alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] + acc *= alpha[:, None] acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) # -- update m_i and l_i -- l_i = l_i * alpha + tl.sum(p, 1) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 71aa9651b7a8..996235c793be 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -96,8 +96,7 @@ def _fwd_kernel( alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] + acc *= alpha[:, None] acc += tl.dot(p.to(tl.float16), v) # -- update m_i and l_i -- l_i = l_i * alpha + tl.sum(p, 1) From e721911705a44f634bedc8089b4fe57015e1ffc7 Mon Sep 17 00:00:00 2001 From: Wang Weihan Date: Tue, 5 Sep 2023 12:31:38 +0800 Subject: [PATCH 011/122] [FRONTEND] clean build directly when executing python setup.py clean (#2238) Current setup.py could not clean the build directly because the default build directly has been changed in `CMakeBuild`. This PR is to clean build directly in this regard. --- python/setup.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/python/setup.py b/python/setup.py index 18764ec13165..8ab3839c313f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -8,6 +8,7 @@ import tarfile import tempfile import urllib.request +from distutils.command.clean import clean from pathlib import Path from typing import NamedTuple @@ -158,6 +159,25 @@ def download_and_copy_ptxas(): # ---- cmake extension ---- +def get_base_dir(): + return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + + +def get_cmake_dir(): + plat_name = sysconfig.get_platform() + python_version = sysconfig.get_python_version() + dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" + cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name + cmake_dir.mkdir(parents=True, exist_ok=True) + return cmake_dir + + +class CMakeClean(clean): + def initialize_options(self): + clean.initialize_options(self) + self.build_temp = get_cmake_dir() + + class CMakeBuildPy(build_py): def run(self) -> None: self.run_command('build_ext') @@ -178,10 +198,7 @@ class CMakeBuild(build_ext): def initialize_options(self): build_ext.initialize_options(self) - self.base_dir = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - os.pardir)) + self.base_dir = get_base_dir() def finalize_options(self): build_ext.finalize_options(self) @@ -200,14 +217,6 @@ def run(self): for ext in self.extensions: self.build_extension(ext) - def get_cmake_dir(self): - plat_name = sysconfig.get_platform() - python_version = sysconfig.get_python_version() - dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" - cmake_dir = Path(self.base_dir) / "python" / "build" / 
dir_name - cmake_dir.mkdir(parents=True, exist_ok=True) - return cmake_dir - def build_extension(self, ext): lit_dir = shutil.which('lit') user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or \ @@ -265,7 +274,7 @@ def build_extension(self, ext): "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"] env = os.environ.copy() - cmake_dir = self.get_cmake_dir() + cmake_dir = get_cmake_dir() subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) @@ -300,7 +309,7 @@ def build_extension(self, ext): ], include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], - cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy}, + cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, zip_safe=False, # for PyPI keywords=["Compiler", "Deep Learning"], From 60643f2a2d82446c08ae968cf1c8d209e3ce88d2 Mon Sep 17 00:00:00 2001 From: Thomas Date: Wed, 6 Sep 2023 00:52:06 -0700 Subject: [PATCH 012/122] [BACKEND][NFC] Simplify coalescing pass (#2230) Remove unnecessary templates and simplify the mapping of operations to encoding. --- lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 165 +++++++----------- 1 file changed, 63 insertions(+), 102 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 9d6cd240d0a9..4e32f7ab2e8e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -34,12 +34,26 @@ unsigned getElementBitWidth(const Value &val) { return typeForMem.getIntOrFloatBitWidth(); } -typedef DenseMap> LayoutMap; +static Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto insert = dyn_cast(op)) + return insert.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} struct CoalescePass : public TritonGPUCoalesceBase { - Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, - Value ptr, Operation *op, int numWarps, - int threadsPerWarp) { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); auto refType = ptr.getType(); if (refType.isa()) refType = refType.cast().getPointeeType(); @@ -88,7 +102,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { order = argSort(queryAxisInfo(ptr).getContiguity()); } - auto matchesOrder = [&refTensorType](const Value &val) { + auto matchesShape = [&refTensorType](const Value &val) { if (val.getType() == refTensorType) { return true; } @@ -104,17 +118,19 @@ struct CoalescePass : public TritonGPUCoalesceBase { // among all dependent pointers who have the same order as // `ptr`. // We only do it for normal tensors of pointers, not tensor pointers. 
- SetVector withSameOrder; - withSameOrder.insert(ptr); + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); if (refType.isa() && ptr.getDefiningOp()) { - for (Operation *op : mlir::multiRootGetSlice(ptr.getDefiningOp())) { - for (Value val : op->getResults()) { - if (!matchesOrder(val)) - continue; - auto currOrder = - argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); - if (order == currOrder) - withSameOrder.insert(val); + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val) + continue; + if (!matchesShape(val)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + memAccessesSameOrder.insert(use); } } } @@ -133,10 +149,8 @@ struct CoalescePass : public TritonGPUCoalesceBase { .getPointeeType() : refTensorType.getElementType(); - // Thread tile size depends on memory alignment - SmallVector sizePerThread(refTensorType.getRank(), 1); - unsigned perThread = 1; - for (Value val : withSameOrder) { + auto getNumElementPerThread = [&](Operation *op) { + Value val = getMemAccessPtr(op); auto valInfo = queryAxisInfo(val); unsigned elemNumBits = getElementBitWidth(val); unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); @@ -146,6 +160,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); unsigned alignment = std::min(maxMultiple, maxContig); unsigned currPerThread = std::min(alignment, 128 / elemNumBits); + return currPerThread; + }; + unsigned perThread = getNumElementPerThread(op); + for (Operation *op : memAccessesSameOrder) { + unsigned currPerThread = getNumElementPerThread(op); perThread = std::max(perThread, currPerThread); } @@ -159,60 +178,53 @@ struct CoalescePass : public TritonGPUCoalesceBase { // For loads, we can expect that the gaps won't matter due to the L1 // cache. 
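      // e.g. with 32-bit elements the store is capped at 128 / 32 = 4 elements
      // per thread, while 16-bit stores may still use 128 / 16 = 8.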
unsigned elemNumBits = getElementBitWidth(ptr); - perThread = std::min(perThread, 128 / elemNumBits); + perThread = std::min(perThread, getNumElementPerThread(op)); } + SmallVector sizePerThread(refTensorType.getRank(), 1); sizePerThread[order[0]] = perThread; auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); - return triton::gpu::BlockedEncodingAttr::get( + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, threadsPerWarp, CTALayout); } - std::function - getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr, - Operation *op, int numWarps, int threadsPerWarp) { - Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, op, - numWarps, threadsPerWarp); - return [encoding](Type type) { - RankedTensorType tensorType = type.cast(); - return RankedTensorType::get(tensorType.getShape(), - tensorType.getElementType(), encoding); - }; + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = type.cast(); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); } - template - void coalesceOp(LayoutMap &layoutMap, Operation *op, Value ptr, - OpBuilder builder) { - if (!layoutMap.count(ptr)) - return; - + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); // Convert operands // For load/store with tensor pointers, we don't have to change the // operands' type, we do this by changing the outputs' type of // `make_tensor_ptr` - auto convertType = layoutMap.lookup(ptr); SmallVector newArgs; for (auto operand : op->getOperands()) { auto tensorType = operand.getType().dyn_cast(); if (tensorType && - !tensorType.getEncoding().isa()) + !tensorType.getEncoding().isa()) { + Type newType = getNewType(tensorType, encoding); newArgs.push_back(builder.create( - op->getLoc(), convertType(tensorType), operand)); - else + op->getLoc(), newType, operand)); + } else { newArgs.push_back(operand); + } } // Convert output types SmallVector newTypes; for (auto t : op->getResultTypes()) { - bool isAsync = std::is_same::value; - newTypes.push_back(isAsync ? t : convertType(t)); + bool isAsync = isa(op); + newTypes.push_back(isAsync ? 
t : getNewType(t, encoding)); } // Construct new op with the new encoding Operation *newOp = - builder.create(op->getLoc(), newTypes, newArgs, op->getAttrs()); + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); // Cast the results back to the original layout for (size_t i = 0; i < op->getNumResults(); i++) { @@ -226,25 +238,6 @@ struct CoalescePass : public TritonGPUCoalesceBase { op->erase(); } - void coalesceMakeTensorPtrOpResult(LayoutMap &layoutMap, Operation *op, - Value ptr, OpBuilder builder) { - if (!layoutMap.count(ptr)) - return; - - // Convert result type - auto convertType = layoutMap.lookup(ptr); - auto ptrType = ptr.getType().cast(); - auto resultTensorType = convertType(ptrType.getPointeeType()); - auto newResultType = - PointerType::get(resultTensorType, ptrType.getAddressSpace()); - - // Build new operation and replace - Operation *newOp = builder.create( - op->getLoc(), newResultType, op->getOperands(), op->getAttrs()); - op->getResult(0).replaceAllUsesWith(newOp->getResult(0)); - op->erase(); - } - void runOnOperation() override { // Run axis info analysis ModuleOp moduleOp = getOperation(); @@ -252,19 +245,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { // For each i/o operation, we determine what layout // the pointers should have for best memory coalescing - LayoutMap layoutMap; + llvm::MapVector layoutMap; moduleOp.walk([&](Operation *curr) { - Value ptr; - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); - if (auto op = dyn_cast(curr)) - ptr = op.getSrc(); - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); + Value ptr = getMemAccessPtr(curr); if (!ptr) return; // We only convert `tensor>` or `tt.ptr>` load/store @@ -279,9 +262,8 @@ struct CoalescePass : public TritonGPUCoalesceBase { int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - auto convertType = getTypeConverter(axisInfoAnalysis, ptr, curr, numWarps, - threadsPerWarp); - layoutMap[ptr] = convertType; + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); }); // For each memory op that has a layout L1: @@ -291,30 +273,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { // produces a tensor with layout L2 // 4. Convert the output of this new memory op back to L1 // 5. 
Replace all the uses of the original memory op by the new one - moduleOp.walk([&](Operation *curr) { - OpBuilder builder(curr); - if (auto load = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, load.getPtr(), builder); - return; - } - if (auto op = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, op.getPtr(), builder); - return; - } - if (auto op = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, op.getPtr(), builder); - return; - } - if (auto load = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, - load.getSrc(), builder); - return; - } - if (auto store = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, store.getPtr(), builder); - return; - } - }); + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } } }; From 92e2c32283c3f6e8a16d6cb562da59713fa3a17e Mon Sep 17 00:00:00 2001 From: Wang Weihan Date: Wed, 6 Sep 2023 16:03:31 +0800 Subject: [PATCH 013/122] [BACKEND] Rebased Intel GPU Backend to be comptiable with latest Triton version (#2245) --- third_party/intel_xpu_backend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel_xpu_backend b/third_party/intel_xpu_backend index 0bcc485f82b3..d05dc79dad63 160000 --- a/third_party/intel_xpu_backend +++ b/third_party/intel_xpu_backend @@ -1 +1 @@ -Subproject commit 0bcc485f82b34d49494bd0264bacc24a20aafb7a +Subproject commit d05dc79dad638b8ebbacfef44886f568b5885fc3 From 36859aebff1897f3bdf6dd19ba775e5961848943 Mon Sep 17 00:00:00 2001 From: jon-chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:17:12 +0800 Subject: [PATCH 014/122] [DOCS] Add MLIR Autogenerated Docs to Sphinx Docs (#2234) Partially fixes: https://github.com/openai/triton/issues/2226 Here are some example renderings: ![Screenshot from 2023-09-04 18-39-20](https://github.com/openai/triton/assets/9093549/e9c4af04-aeae-4021-a8db-6a4a82b59ae7) ![Screenshot from 2023-09-04 18-39-30](https://github.com/openai/triton/assets/9093549/410391b8-e07e-4bed-909c-8ce5484072d1) ![Screenshot from 2023-09-04 18-39-41](https://github.com/openai/triton/assets/9093549/f1eaef95-66c1-4506-a153-c6069e2b5072) --- .github/workflows/documentation.yml | 1 + .gitignore | 9 +++ .../unicode_data/13.0.0/charmap.json.gz | Bin 0 -> 20988 bytes _deps/googletest-src | 1 + docs/conf.py | 59 +++++++++++++++++- docs/index.rst | 12 ++++ .../triton/Dialect/NVGPU/IR/CMakeLists.txt | 4 ++ .../triton/Dialect/Triton/IR/CMakeLists.txt | 4 ++ .../Dialect/TritonGPU/IR/CMakeLists.txt | 4 ++ .../Dialect/TritonNvidiaGPU/IR/CMakeLists.txt | 4 ++ .../unicode_data/13.0.0/charmap.json.gz | Bin 0 -> 20988 bytes python/setup.py | 1 + 12 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 .hypothesis/unicode_data/13.0.0/charmap.json.gz create mode 160000 _deps/googletest-src create mode 100644 python/.hypothesis/unicode_data/13.0.0/charmap.json.gz diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index d413a3dca171..7993b5733336 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -25,6 +25,7 @@ jobs: pip3 install tabulate pip3 install cmake pip3 install sphinx + pip3 install myst_parser #- name: Fetch dependent branches # run: | diff --git a/.gitignore b/.gitignore index ef7867cbde86..cd3d84ead5a9 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,12 @@ cmake-build-* # Third-party binaries ptxas + +# Docs +docs/_build/ +docs/python-api/generated/ +docs/dialects/ +docs/getting-started/tutorials +python/tutorials/ +!python/tutorials/*.py +!python/tutorials/*.rst diff --git 
a/.hypothesis/unicode_data/13.0.0/charmap.json.gz b/.hypothesis/unicode_data/13.0.0/charmap.json.gz
new file mode 100644
index 0000000000000000000000000000000000000000..63a9ba0ccf8ffbcfdc83fc4f7c9d7aee50930157
GIT binary patch
literal 20988
zZ-j$3*CS^v=1!Leaj|~+x1RyPG^*`->a2gh9e_oaFB&h3s(+nc6gwBiZkcMM1H32) zRE>-==Kn3%hhB$^>al@AnvW$)w%Q@U&;f6{4XY8r7O|iGxxO>qkf!3VkPYKF=egm6 zZaBHYHr$NE+6X26xpbx0T%iL6+0=y)pYu6*!4k@FY!+ z@PO;&l0C2J%VCVT+6RD7LD|0_aJycQcT87_l(E}`yDxoW*gVM#+A)lyf5Hm3H}0=4 zDd*ew0PstYco*{!|7?F@)!GYvk`GY2R7i6@gKx^s-gduZNK|)Q zU$aVvx_Ws{3&hjLafy)pG(m*l_|z1Vb~qLPqb3a}EM#nzB=P@BO&ob%36rDJB%kM| zqkDKhgYPslt30efY5b?BS*<)G&O0X1s2l`NN{{gF2*P~2fig$MR*uQ2Yme9S+_`7+ zS)5lS^9cBS?5yj85Z=~_7+l=Yx*!Y7z~0srcxUVEVW-%km#{n{vVUP7}@RP007f4#W(-} literal 0 HcmV?d00001 diff --git a/_deps/googletest-src b/_deps/googletest-src new file mode 160000 index 000000000000..58d77fa8070e --- /dev/null +++ b/_deps/googletest-src @@ -0,0 +1 @@ +Subproject commit 58d77fa8070e8cec2dc1ed015d66b454c8d78850 diff --git a/docs/conf.py b/docs/conf.py index 23ff8ecc9e54..54ca524685b7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,7 +24,10 @@ import os +import shutil import sys +import sysconfig +from pathlib import Path import sphinx_rtd_theme from sphinx_gallery.sorting import FileNameSortKey @@ -36,6 +39,58 @@ def process_sig(app, what, name, obj, options, signature, return_annotation): return (signature, return_annotation) +def get_cmake_dir(): + plat_name = sysconfig.get_platform() + python_version = sysconfig.get_python_version() + dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" + cmake_dir = Path("../python") / "build" / dir_name + return cmake_dir + + +def setup_generated_mlir_docs(): + dst_path = Path("dialects") + os.makedirs(dst_path, exist_ok=True) + + cmake_dir = get_cmake_dir() + src_dir = cmake_dir / "docs" / "dialects" + assert os.path.isdir(src_dir) + + shutil.copytree(src_dir, dst_path, dirs_exist_ok=True) + + files = os.listdir(dst_path) + + dialects = "\n ".join(["./" + f for f in files if "Dialect" in f]) + ops = [f for f in files if "Ops" in f] + + # Add titles + for op in ops: + with open(dst_path / op, 'r+') as f: + lines = f.readlines() + lines.insert(0, "# " + op.split(".md")[0]) + f.seek(0) + f.writelines(lines) + ops = "\n ".join(["./" + op for op in ops]) + + rst_string = f""" +Triton MLIR Dialects and Ops +===================== + +.. toctree:: + :maxdepth: 1 + :caption: Dialects + + {dialects} + +.. toctree:: + :maxdepth: 1 + :caption: Dialect Ops + + {ops} +""" + with open(dst_path / "dialects.rst", "w+") as f: + f.write(rst_string) + + def setup(app): """Customize function args retrieving to get args under decorator.""" import os @@ -44,6 +99,7 @@ def setup(app): app.connect("autodoc-process-signature", process_sig) os.system("pip install -e ../python") + setup_generated_mlir_docs() def forward_jit_fn(func): old = func @@ -82,7 +138,8 @@ def documenter(app, obj, parent): 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', - 'sphinx_multiversion'] + 'sphinx_multiversion', + 'myst_parser'] autosummary_generate = True # versioning config diff --git a/docs/index.rst b/docs/index.rst index 080f942392c9..b72c9352ed93 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -37,6 +37,18 @@ Python API python-api/triton.testing +Triton MLIR Dialects and Ops +-------------------- + +- :doc:`Triton MLIR Dialects and Ops ` + +.. 
toctree:: + :maxdepth: 1 + :caption: Triton MLIR Dialects + :hidden: + + dialects/dialects + Going Further ------------- diff --git a/include/triton/Dialect/NVGPU/IR/CMakeLists.txt b/include/triton/Dialect/NVGPU/IR/CMakeLists.txt index aa965dac6284..f8932cdc4b7f 100644 --- a/include/triton/Dialect/NVGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/NVGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS NVGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu) @@ -6,6 +8,8 @@ mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(NVGPUDialect NVGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(NVGPUOps NVGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(NVGPUTableGen) set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td) diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt index 84bd723f63ac..42e6c039d2aa 100644 --- a/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -1,12 +1,16 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS TritonOps.td) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) set(LLVM_TARGET_DEFINITIONS TritonDialect.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) set(LLVM_TARGET_DEFINITIONS TritonTypes.td) mlir_tablegen(Types.h.inc -gen-typedef-decls) diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index d32192749f25..7b7ca5593afd 100644 --- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) @@ -5,6 +7,8 @@ mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index aba08ab137d3..b7ce83fe7ea6 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu) @@ -5,6 +7,8 @@ mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) 
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu) mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu) +add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonNvidiaGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td) diff --git a/python/.hypothesis/unicode_data/13.0.0/charmap.json.gz b/python/.hypothesis/unicode_data/13.0.0/charmap.json.gz new file mode 100644 index 0000000000000000000000000000000000000000..c6b036796449f41d4de90e2a0bfcc8eddfd055e6 GIT binary patch literal 20988 zcmbUIV{|25`0fkGcGBtCM#r|@vDvZFvD2|_t=P7$PC7PMY?~`5&))y{J$sDv<&5*? zs;g#I%{l9yg;Do!7I73D95@&l7{sTGog30Ka3& z82{6>cy8P>Q1bCrXz^C4X@03f@#lJ5p!4UJyU?#Q#napkH{jjr)CZya52u@m%E9(q zqtAtEp7-7(zUJxaQC}?gAm=zge?48<&dbA&CmWh{_&a*_V{$ci76shuRTf zHHsC3fPl33iKiWZtpfa!T+WsFgQo*Q?g7`;UB|jidG6oMg0!u*?t|uH_TI#jV(^IW z?WxDh8|Ao~mSpqPp_^HYP2ieJ=0hhBg=!K)=r zshV0jop_JX5dR{+70p^;wIGF^H1F^G#LymHS2LTt@#_+oZdzaHS0p*`2#b}r)Xo$a zlAE4$4U0Bi8N|87+csV$a~Z6VkZxFScOOC~II=7b}?9=%% zFRJAJrYv53CMUS1|9RSQdZlY zF!R~h(-qcaXS8#&sU(kXXXCm9}E$|=YD zP3`#SxJ&Z3xorYg$w|EC6g4T78b_1u@z5IE?H;7fG^bM}cSNhDE;U^lSH?z})A;9~ zK)_X^i_lSxO-45CXqC2!Y*3I(;yQ|S z7Zc)>*9|g`O`+oEW=LnJix8JzYkSScK}!v#%QVCj1NZzECrSPE@k8f~bgb8>wH!P- zeH(8~S^KvAqFlFBb4sv*N;koKY~_4kwri`HN)9RP=U{)MK*yA7lRMq#*g^}@iHhnXZ|cMCE|+x__+?i2kFYh_wfI1ln6Wik1nGB_v%$@V zw7%xx8tA8U`&A2y9ywq4>&;pBBjH`4aXdhCN)zdgkM?v~OS0o)dF4+%*V76XldX#q zwYv#^exN$WWsYHHV~vhut(nYxyPfDj^G5MI$o_=)`&6q-sOu@O4Sjhm|LbNsOWV|( zJI6Z1WYgIKVHKmxYOPm!b&)p4=+Q`uo&KsiwwH_#&>erK>&4CXpR2C(ru`qFlh36= z<(fLd3Zi*T=RrbY?%?Lv9mlzp`gI9p=rwa~fjS8ju_^po9dqQC+fS?M%yLyeo;nYj z&+(YtC$8MAdUZP{w>S`ijB#Cihq8DA_*W~NZn7LTts@hJO|Fj%0u$gepQL93Yo)99 zW!@#qj<6TqQ=Jk_?y@wP1vj4(ObYMzT*Y(6m)KXz(Q|Yrb<;SA7^B+JnQ1i`!8Y)g zPb@rBH7csU_NS3f*3E6Yxr8>)n{&Tkt^Wb5i!@piM}~4lNoFP67-RQcw>U{>cVAC` z7A@_*;XNcie>N<#RWKIMMOe#~oA-1UHwV2X>ZmqC7tLIqmM^5v{Tb@?PWiDCsJD12 zwp#b}+{vOL(_~OX=O(+GG?h??^kVD0JI)T{+>fNoKX~GL6lXg!e|9dVc+kVj4yJ3E zc;;4^5TZqgjQ95eB0s-Nkf4{F7)S{6dwHrf2zz8W>};f+*xKOlWq(-1jc=#h%;|e^ zKlbTN24hb|{ii%oO^Tb?#8nyZ_ z-MzHv_U-xgQ~%1d+M_=t&PLjyI9vlW(MG!`MkCvMS0`Y+rH}wqu2{3O)7<*xUN-We z2+-Knb_6dcI&r6+hj1jj^qohkJhg2;>&5M`K++!_H*>vBxKOiw3vTtEOB8^ z#GGiY+SWc*|2x-k`IBL+{BCm*wqa?d9{+V+;X{n2V{+!DF!^CU7{s)YXf~2DPreE;Kj;h!@%A_%(%vh^BNAfytKLYl2a=Czt$be`@WA);cem!>GskFr2O;Yqk02G*tTtOKyRyEE=FG=A|FfOz zfPh=DyF>J&Qu6GqJj9f#8r&WBL)W66@7TfWeGwsAqV>hY18i1$zlV%x*h^b(~v4P?kCqfOMMJ#NmyD@4Gyc>Lrzl2L$%9%s`%n`B z;xOFE!2?=S5RV`I#w zbq+Q@&1@ucQmclOX>TW{ooN`#G3@yUcgub{YoR-ulkxntXP;N#aVJFOSJb&r^P88v zkPh>d|++zc|3wnrNLM4)cniP!^S!OAcV4m{wu5YEyAu^cdV1Ah#U7w0-Gvgbq%Y3YI_y+ z(Ko(XrSmG~#myRgK@G7EZM}=qm;60%&y>Y&+6Kw0zuT$p3y(ZgxczG1vIwWynsW_( z@jgFKZBPBT{j_a5y&l8%j!Hq)xubUX{)Z!17W~}=Kqjwq8`$MWpAW=6N;_H#q=iak zOMWWgvJa>5fS$DHt=p|%DaR*&x=`F#5ZKT~jx+shU!O%vP|Mt?%!uR81m@14GTJ^z z_a6Qsd!G2|T@Za(b#GcMzq;c~u|u5A+ueH~6SvOZ@vL%@6Re}h6;*+@E4mc;CtyAm z&rX`1m_d~Z(KkNqpxrk?93dTtJ(~x?!UVn7Q@dx-M$Y6_YtPkt_hIE4PjcM8mw_dN z!?t8MiTI(~Z*$t4hSFWvFFVmIcYDcz4a%V^XVOx^?z zJp5Z^q&EBXyYuJ)HV@u~On>L!`t-m3RQxooifiRLCBAdwm}Grj_R`{O19ji~Y{!a{ zd+vOOU7;QV*KH@5IoBKQ(KpINKlJ5b8{YGsI39Qn=ieONuRm)c1Sw#9pRS~;9e#PA zTrT+FO}5x)JqT*dQw}oQdwpJ^T!vi$Ea7lAB*vD#d0&OHzaAn*{D3anGcVLh_ZM&L zBPBO~*bn0WItDt%kD7YQfAc+_5gMGWJ4W8S;7nyPuqXEyseEFwHB3GVK!N%I{^SyQ z^0*|3(|_F9?0>$0*>Bal&k=x?G+oQ=X=-VA}r|+)44Wriw{3NlP-UD z_qOi5Ij3dl2z}LkbGh*Dj{jnP`SSxn>-$u}0f`c@ev5R0(&i=MP|4BoNR>?hXzJ-q 
zgQ*x7=cGT-%S=2sZ9|xbvOO^^i81|}A_4nSIT*m(>&|v$c6Wb=x)@zmF|KU~M8Atv zmk+(RVxSKo$z2RKU=L}Q0>hUwuq&8Dl;yTo)~VHXYh_(p*YUlgR%HJX%fT(7NNb1X z@0>NXnS!}Jws~!ApKm|q?SJ?0zv-o~@B(fAT1s>$ky_)-8YI>zTLuQUa%E~?xGt>f z#^!=rRG8XY-L%%#N~v`oC?Q&CK-`EI0Skuw&#LJgmlKAyu>96Hjy-a|s2xJmkwM5|2YJ4%q|Ad{ITeCIgw(IGY%8`00G& zHCQ;&`rbmM7mc4}@X6{q#bK!UUr-@`_#4+=07Z#&FYLPh5B`mwF&;3V8OoWsatYF% z==Z3>EsO}Nzk7D(WZdkVLDtoj?Ym(YZ+XfBHghNf3wv7ejs$osRaJt4xtxxV*L0Y& zgA=}>RK1)8fl>-%o1j0s+aW0-afOHZm$LqDJ0bB}U<#%z!2}RGg@e8=dRgiP|7hzU zfvGl?rI3db_1^wcv3T@9`Re72}R z5`TYwtL_eC3--;E#HywaVt_?9hFVi%S@CdI)q z#KmC-4>=9KS10LP;wllNLG3Z*@U6!dqB=lNL%Ms6Wq#W=kiMNE29Jj$(b?mjJy?Os z*cfnmG;w;kE|Kal-*)*e7UBZVOyw7X*HCd(aKMH06_Ktej}g=yan~Wc;4%53Fhg^d zC#0G;AtM_MliI3Q5hNixkU!LwKh#mip{tGq;R#*bo8w#H&)BYE!y-1yFgq+obCkJ- z>@4wapnkPc`U6g*Aq2y)lM9 z4=#PxJW0H2skJ6iek4_YBvF9mZkHtRd5ZNEgm|f(h{T(fs=<|D!Aa$#-1cBb*361I zUqyFsyD&fNcUS@T(s#HMTX$lE*>r10-mnI@V-%RX#ZHpQSIXn>pvx$C{Co^yq-S!> z-Ayz36!uaxIgIF~_=#)ziQNea*L;2lYucUVr7taLOp3eg8oNix_X7^_BhEC|LoMUe z%dKklg#*z1-sS7s#}LtHwdj? zZ-j$3*CS^v=1!Leaj|~+x1RyPG^*`->a2gh9e_oaFB&h3s(+nc6gwBiZkcMM1H32) zRE>-==Kn3%hhB$^>al@AnvW$)w%Q@U&;f6{4XY8r7O|iGxxO>qkf!3VkPYKF=egm6 zZaBHYHr$NE+6X26xpbx0T%iL6+0=y)pYu6*!4k@FY!+ z@PO;&l0C2J%VCVT+6RD7LD|0_aJycQcT87_l(E}`yDxoW*gVM#+A)lyf5Hm3H}0=4 zDd*ew0PstYco*{!|7?F@)!GYvk`GY2R7i6@gKx^s-gduZNK|)Q zU$aVvx_Ws{3&hjLafy)pG(m*l_|z1Vb~qLPqb3a}EM#nzB=P@BO&ob%36rDJB%kM| zqkDKhgYPslt30efY5b?BS*<)G&O0X1s2l`NN{{gF2*P~2fig$MR*uQ2Yme9S+_`7+ zS)5lS^9cBS?5yj85Z=~_7+l=Yx*!Y7z~0srcxUVEVW-%km#{n{vVUP7}@RP000D$#VG&) literal 0 HcmV?d00001 diff --git a/python/setup.py b/python/setup.py index 8ab3839c313f..bcdc5faa3107 100644 --- a/python/setup.py +++ b/python/setup.py @@ -277,6 +277,7 @@ def build_extension(self, ext): cmake_dir = get_cmake_dir() subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) + subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) download_and_copy_ptxas() From 3dec616c7c14d2141277b37a334ef5857108a8ac Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:21:29 -0700 Subject: [PATCH 015/122] [CI] Fix submodule issue (#2253) ... 
--- .github/workflows/integration-tests.yml | 5 +++-- _deps/googletest-src | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) delete mode 160000 _deps/googletest-src diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 1b83b566359f..097c55e590d1 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -44,8 +44,9 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 - + uses: actions/checkout@v3 + with: + submodules: 'true' - name: Set CUDA ENV if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | diff --git a/_deps/googletest-src b/_deps/googletest-src deleted file mode 160000 index 58d77fa8070e..000000000000 --- a/_deps/googletest-src +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 58d77fa8070e8cec2dc1ed015d66b454c8d78850 From f21b36c8c54f35a88e96d7217e2c6bc9cc02ee69 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:42:42 -0700 Subject: [PATCH 016/122] [CLEANUP] Delete binaries that went in by mistake (#2256) --- .hypothesis/unicode_data/13.0.0/charmap.json.gz | Bin 20988 -> 0 bytes .../unicode_data/13.0.0/charmap.json.gz | Bin 20988 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .hypothesis/unicode_data/13.0.0/charmap.json.gz delete mode 100644 python/.hypothesis/unicode_data/13.0.0/charmap.json.gz diff --git a/.hypothesis/unicode_data/13.0.0/charmap.json.gz b/.hypothesis/unicode_data/13.0.0/charmap.json.gz deleted file mode 100644 index 63a9ba0ccf8ffbcfdc83fc4f7c9d7aee50930157..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20988 zcmbUIV{|25`0fkGcGBtCM#r|@vDvZFvD2|_t=P7$PC7PMY?~`5&))y{J$sDv<&5*? zs;g#I%{l9yg;Do!7I73D95@&l7{sTGoufOOyRkh7yDON2{@w%$^wP?>=+V{D*7K2Oue8Bfd5C9;m7U7QVwezhto|&rPwL}ibz>`O}WJ8b3RL+uE# z8pVo1KtS62#M6$yRssG|gvCl*YG;ZI z$xY9>hDDpM4B}j3a=eSF6qA1ULewUg)RG-+5iNXg)~o$4lIn7U<33)ZV^6gpq-HAb z?Y6s*qPGv_u!WV_u08_)$xKqL7p09H{2rwXBg;)^`IYrl8Dnt_n{Cdoy|HR?uBp6P zya?b8<9P)AL|nINC2*6Gl7D?)*LD4r(sdf+iNSTc!MtNlrQ)fqB@rWWD!`LT3`iL( zpTp!7TFAC7z)M`rM=i38`?He1Q&$E*v1f4IiE7&M`j9}NH;CHJCC6epo!iR~_UZhX z7gch9Qx-2il!BoA#`9Xk^9MI>yrcW5ru&oq_hK*SOE+l*X_4~1v6?GugO2%EDXZ6AOHlMD|n<&@+7 zrgr>u+$DM2+%^HLjoLerciNnGo-WAMTpChE3U zOI1tqsf=Q`$(`T;4fNBw{i+2;kDRaj_2#Vmk?^k2I3A!mrHS;$M|--gCE0PYyz-}>>uCjx$<{@Q z+T8>{KTsXxGRH8pu|~(S)=Xx;-A;6%d87CpWPif@eX7+Z)b*6thQ2(O|8=vRrEO}? 
zm4E5}WYss8l)f5ywxK3B`*q<4SNWOfRH+p+p)0}>4@_MlR>uUlRBGKPi+`$@g!d`s z$xb-3Q-18*Vs2aNWn;Do%@zUOHq>X}NY7l0E9GiWPRg&$Dv8A$FBY+~i|IDHFSYzc}blzMesxYtAj6bVneC?1R>0_2W*VrHfti(45l>)7Cw<*tW*+e3h^`rS`;pX!uag9rCiA6qvr5ETT7Zk zv*vJUaad+uG4b`4V+q$I??|Z>jhclm6coN$-GyB`HJfhDrc<+V{C45htwx?M&~R1r za`mXG*diLdAH#)dkobhgg=zSGKJN+6r(EDFv@%$DW!O{J=_yluoAKQ(e|Wh`8oed2 z@CYF|ma$Hv(R>EYDaFUPsaR}l7MoRg$Xj_vc4$ zDtz($$SfS8cqQwHR}+?rub<9VS@bh6$7cc9X?lMv@H!ZOb)g6@3-Todxevlz2Nf7S zjE}xT4U_TD?f52PMf(+2Sd3S`_~T~$QSnCN1?~ppchiLk zfd3BtpGm*^6d%zYb_w-OsICh11>=1oImtVkmJj29VgwzCkKcxxcGZ7QmVtg;D6Hc> zULkKGuDd-}@P4aY`;Wf=Px=*56^`+%id|_=$BZt&0_c&C`H?}8EM{FST`QioK+lU- z+YT=SXYoXJwNxS11pRf2|N6c_13rf#pF_gukn%a83OxNj2Q@j)XCDA;uI7X;Z}t^8wk9#y*FSK8Mdf2l^n=J_E40eCD#Gtd1;n)^>Bt;%`O=tF3h2W9&#K!X3NWNYoPzv2D{`y1|W zvcKv6Ci`f|vCRI2bR#kZjpa#s+UJn;IarfnKsn8dlT4USv*ILKanh_fNmiV+ zrj=xs$)A#po0`ar>?*M9l8g!unc0+Nk=<;Ooph1ie37i@jmAC&{=Pk5Ab;Op+*n~_ z{~l&kA1)b(TF_{`(g#<1muTZkXXNYpXp?^~srGRt{=D)5c8SI^y{yWRTuhN%pI2&v zXV+&R0eo_3Ec5qXsT;HCn@g%SWRfkH)LJa5RgE=+q-G#Z0>~LOUdg%A$MZ=n%GM0% zgQ*x7=cGT-%S=2sZ9|xbvOO^^i81|}A_4nSIT*m(>&|v$c6Wb=x)@zmF|KU~M8Atv zmk+(RVxSKo$z2RKU=L}Q0>hUwuq&8Dl;yTo)~VHXYh_(p*YUlgR%HJX%fT(7NNb1X z@0>NXnS!}Jws~!ApKm|q?SJ?0zv-o~@B(fAT1s>$ky_)-8YI>zTLuQUa%E~?xGt>f z#^!=rRG8XY-L%%#N~v`oC?Q&CK-`EI0Skuw&#LJgmlKAyu>96Hjy-a|s2xJmkwM5|2YJ4%q|Ad{ITeCIgw(IGY%8`00G& zHCQ;&`rbmM7mc4}@X6{q#bK!UUr-@`_#4+=07Z#&FYLPh5B`mwF&;3V8OoWsatYF% z==Z3>EsO}Nzk7D(WZdkVLDtoj?Ym(YZ+XfBHghNf3wv7ejs$osRaJt4xtxxV*L0Y& zgA=}>RK1)8fl>-%o1j0s+aW0-afOHZm$LqDJ0bB}U<#%z!2}RGg@e8=dRgiP|7hzU zfvGl?rI3db_1^wcv3T@9`Re72}R z5`TYwtL_eC3--;E#HywaVt_?9hFVi%S@CdI)q z#KmC-4>=9KS10LP;wllNLG3Z*@U6!dqB=lNL%Ms6Wq#W=kiMNE29Jj$(b?mjJy?Os z*cfnmG;w;kE|Kal-*)*e7UBZVOyw7X*HCd(aKMH06_Ktej}g=yan~Wc;4%53Fhg^d zC#0G;AtM_MliI3Q5hNixkU!LwKh#mip{tGq;R#*bo8w#H&)BYE!y-1yFgq+obCkJ- z>@4wapnkPc`U6g*Aq2y)lM9 z4=#PxJW0H2skJ6iek4_YBvF9mZkHtRd5ZNEgm|f(h{T(fs=<|D!Aa$#-1cBb*361I zUqyFsyD&fNcUS@T(s#HMTX$lE*>r10-mnI@V-%RX#ZHpQSIXn>pvx$C{Co^yq-S!> z-Ayz36!uaxIgIF~_=#)ziQNea*L;2lYucUVr7taLOp3eg8oNix_X7^_BhEC|LoMUe z%dKklg#*z1-sS7s#}LtHwdj? zZ-j$3*CS^v=1!Leaj|~+x1RyPG^*`->a2gh9e_oaFB&h3s(+nc6gwBiZkcMM1H32) zRE>-==Kn3%hhB$^>al@AnvW$)w%Q@U&;f6{4XY8r7O|iGxxO>qkf!3VkPYKF=egm6 zZaBHYHr$NE+6X26xpbx0T%iL6+0=y)pYu6*!4k@FY!+ z@PO;&l0C2J%VCVT+6RD7LD|0_aJycQcT87_l(E}`yDxoW*gVM#+A)lyf5Hm3H}0=4 zDd*ew0PstYco*{!|7?F@)!GYvk`GY2R7i6@gKx^s-gduZNK|)Q zU$aVvx_Ws{3&hjLafy)pG(m*l_|z1Vb~qLPqb3a}EM#nzB=P@BO&ob%36rDJB%kM| zqkDKhgYPslt30efY5b?BS*<)G&O0X1s2l`NN{{gF2*P~2fig$MR*uQ2Yme9S+_`7+ zS)5lS^9cBS?5yj85Z=~_7+l=Yx*!Y7z~0srcxUVEVW-%km#{n{vVUP7}@RP007f4#W(-} diff --git a/python/.hypothesis/unicode_data/13.0.0/charmap.json.gz b/python/.hypothesis/unicode_data/13.0.0/charmap.json.gz deleted file mode 100644 index c6b036796449f41d4de90e2a0bfcc8eddfd055e6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20988 zcmbUIV{|25`0fkGcGBtCM#r|@vDvZFvD2|_t=P7$PC7PMY?~`5&))y{J$sDv<&5*? 
zs;g#I%{l9yg;Do!7I73D95@&l7{sTGog30Ka3& z82{6>cy8P>Q1bCrXz^C4X@03f@#lJ5p!4UJyU?#Q#napkH{jjr)CZya52u@m%E9(q zqtAtEp7-7(zUJxaQC}?gAm=zge?48<&dbA&CmWh{_&a*_V{$ci76shuRTf zHHsC3fPl33iKiWZtpfa!T+WsFgQo*Q?g7`;UB|jidG6oMg0!u*?t|uH_TI#jV(^IW z?WxDh8|Ao~mSpqPp_^HYP2ieJ=0hhBg=!K)=r zshV0jop_JX5dR{+70p^;wIGF^H1F^G#LymHS2LTt@#_+oZdzaHS0p*`2#b}r)Xo$a zlAE4$4U0Bi8N|87+csV$a~Z6VkZxFScOOC~II=7b}?9=%% zFRJAJrYv53CMUS1|9RSQdZlY zF!R~h(-qcaXS8#&sU(kXXXCm9}E$|=YD zP3`#SxJ&Z3xorYg$w|EC6g4T78b_1u@z5IE?H;7fG^bM}cSNhDE;U^lSH?z})A;9~ zK)_X^i_lSxO-45CXqC2!Y*3I(;yQ|S z7Zc)>*9|g`O`+oEW=LnJix8JzYkSScK}!v#%QVCj1NZzECrSPE@k8f~bgb8>wH!P- zeH(8~S^KvAqFlFBb4sv*N;koKY~_4kwri`HN)9RP=U{)MK*yA7lRMq#*g^}@iHhnXZ|cMCE|+x__+?i2kFYh_wfI1ln6Wik1nGB_v%$@V zw7%xx8tA8U`&A2y9ywq4>&;pBBjH`4aXdhCN)zdgkM?v~OS0o)dF4+%*V76XldX#q zwYv#^exN$WWsYHHV~vhut(nYxyPfDj^G5MI$o_=)`&6q-sOu@O4Sjhm|LbNsOWV|( zJI6Z1WYgIKVHKmxYOPm!b&)p4=+Q`uo&KsiwwH_#&>erK>&4CXpR2C(ru`qFlh36= z<(fLd3Zi*T=RrbY?%?Lv9mlzp`gI9p=rwa~fjS8ju_^po9dqQC+fS?M%yLyeo;nYj z&+(YtC$8MAdUZP{w>S`ijB#Cihq8DA_*W~NZn7LTts@hJO|Fj%0u$gepQL93Yo)99 zW!@#qj<6TqQ=Jk_?y@wP1vj4(ObYMzT*Y(6m)KXz(Q|Yrb<;SA7^B+JnQ1i`!8Y)g zPb@rBH7csU_NS3f*3E6Yxr8>)n{&Tkt^Wb5i!@piM}~4lNoFP67-RQcw>U{>cVAC` z7A@_*;XNcie>N<#RWKIMMOe#~oA-1UHwV2X>ZmqC7tLIqmM^5v{Tb@?PWiDCsJD12 zwp#b}+{vOL(_~OX=O(+GG?h??^kVD0JI)T{+>fNoKX~GL6lXg!e|9dVc+kVj4yJ3E zc;;4^5TZqgjQ95eB0s-Nkf4{F7)S{6dwHrf2zz8W>};f+*xKOlWq(-1jc=#h%;|e^ zKlbTN24hb|{ii%oO^Tb?#8nyZ_ z-MzHv_U-xgQ~%1d+M_=t&PLjyI9vlW(MG!`MkCvMS0`Y+rH}wqu2{3O)7<*xUN-We z2+-Knb_6dcI&r6+hj1jj^qohkJhg2;>&5M`K++!_H*>vBxKOiw3vTtEOB8^ z#GGiY+SWc*|2x-k`IBL+{BCm*wqa?d9{+V+;X{n2V{+!DF!^CU7{s)YXf~2DPreE;Kj;h!@%A_%(%vh^BNAfytKLYl2a=Czt$be`@WA);cem!>GskFr2O;Yqk02G*tTtOKyRyEE=FG=A|FfOz zfPh=DyF>J&Qu6GqJj9f#8r&WBL)W66@7TfWeGwsAqV>hY18i1$zlV%x*h^b(~v4P?kCqfOMMJ#NmyD@4Gyc>Lrzl2L$%9%s`%n`B z;xOFE!2?=S5RV`I#w zbq+Q@&1@ucQmclOX>TW{ooN`#G3@yUcgub{YoR-ulkxntXP;N#aVJFOSJb&r^P88v zkPh>d|++zc|3wnrNLM4)cniP!^S!OAcV4m{wu5YEyAu^cdV1Ah#U7w0-Gvgbq%Y3YI_y+ z(Ko(XrSmG~#myRgK@G7EZM}=qm;60%&y>Y&+6Kw0zuT$p3y(ZgxczG1vIwWynsW_( z@jgFKZBPBT{j_a5y&l8%j!Hq)xubUX{)Z!17W~}=Kqjwq8`$MWpAW=6N;_H#q=iak zOMWWgvJa>5fS$DHt=p|%DaR*&x=`F#5ZKT~jx+shU!O%vP|Mt?%!uR81m@14GTJ^z z_a6Qsd!G2|T@Za(b#GcMzq;c~u|u5A+ueH~6SvOZ@vL%@6Re}h6;*+@E4mc;CtyAm z&rX`1m_d~Z(KkNqpxrk?93dTtJ(~x?!UVn7Q@dx-M$Y6_YtPkt_hIE4PjcM8mw_dN z!?t8MiTI(~Z*$t4hSFWvFFVmIcYDcz4a%V^XVOx^?z zJp5Z^q&EBXyYuJ)HV@u~On>L!`t-m3RQxooifiRLCBAdwm}Grj_R`{O19ji~Y{!a{ zd+vOOU7;QV*KH@5IoBKQ(KpINKlJ5b8{YGsI39Qn=ieONuRm)c1Sw#9pRS~;9e#PA zTrT+FO}5x)JqT*dQw}oQdwpJ^T!vi$Ea7lAB*vD#d0&OHzaAn*{D3anGcVLh_ZM&L zBPBO~*bn0WItDt%kD7YQfAc+_5gMGWJ4W8S;7nyPuqXEyseEFwHB3GVK!N%I{^SyQ z^0*|3(|_F9?0>$0*>Bal&k=x?G+oQ=X=-VA}r|+)44Wriw{3NlP-UD z_qOi5Ij3dl2z}LkbGh*Dj{jnP`SSxn>-$u}0f`c@ev5R0(&i=MP|4BoNR>?hXzJ-q z1LqocKVN>H+Gfdww??fAGdC2buSLmbBSLm(I`*qa?R`u{$-8ccd7rTy1T^haKQ`!j zV*`plFb;%Y_N9416#F_dFM%ipB8zLC&M=!&7VSrnuoBH3LKmkf=N4~(t+cd{E_!~? 
z`@c62*$Nwjt7pE6uR}+?UHGjRSC}7kEkZJcNjaEx5lane&|?5Ek4 z%Eiu!8sEE&0-UAm<1Bv7s36U`9q-9jZ2+EdCy@Hs;x*J=I+Uz`31GC4?+YOCE%@x&@e^JLQmX=< z&&5~r9;kf+bv580I|MyHTps3u?V(c~2Gb^INgskTzPtYYPlOk5;sVpl(@4G;SD2vr z=B?V)8pm#Wy;F;BztkFw?lbL<+=m*AGq1-pPgGylE$R-Fw+GDEt4G8%LhgWxG>3|-g1ptSML2T&w<)4LLMw05O!yGiUlXA|eXXURy@hUkUjxqeAB z2w$B0lHP(RbmQAIZhMOv_s_E2@C9HH5^sMFa`Z~F_`0Lsft3l4sP7EcF&-7A;2mynJ-4;&9Ul#=`su;!M zD<;&R8kJI;9`AHOtxnYweK>PlHTyia6WkrraMrrI4$H{E*uRDbW*z%VlS^yJe@7X& zaH({GZ$1RojO%5G#3WQlgGv}Iv_gz^NrV8cB-@via__m8{^p|JzgYTW_nhxnoR9LE zKEdw~5NM*jk1y2?(-~*`pHEyWgRa00KmVd&*P$*>&v5m=(f9D^>3C|M{(&=D0cXKt ztq;&5K);@I%ddk&WY0y77udN58^@eOspH!@w-yR+fEGF8M9NdwzYY_@d z&XPW$I=B=fKo*A7wi89LI+(-0ih zHaf6^U}=bW8E&e8q-bM^aGA2=wePv*e)RiVF~$A$>|i>3T6i3(Fryg4Lt0Yg!6hHLX4feqdZ-^n zYtvYrEp@r}T(+z0Y76>_Y#(7o(!SeQ&I(#R5{WdXw03|aZf17>x<^|TYJwN{;-(mg zk7gws8%rk>jif*~GW9{o&MISltUTi1jtSFCcvLKmJG%_fP1SxJO!L62D6H6`ffWdVpZoXa&WwJa;kN)2)<( zp&DyqE2f1>V{_Vxd1G5Tt+iYP>#oWr-;@8966=kf`biuqU~um)AcPJI@-=MYvb8f> zfL}H6V3a-$5oGlqI}WK^6h~AkteLN<=aBKXUWFC!P4N2DO{(AFx~4txC{*pg4!Y;# zuZ0wVY}vA!lq(?KSpk(iFv52aHO5*o;_FMwX=8ZR`$|n`3i#txf8@_t!V1}BrZ!fa zu8p;rSPjb-_1E$8+YN_riG3ENus{Q_O1UCB~bT%pyE=%6wwq&yJ1$ zndN2I`GZnZO>s}wOHQJwfWF4bY=SbG4#{>n6Agc@iuqU$&O4J_>R8%^_qP#O3KnN2 znH}0VX=*m#g)fp@p_${LhooC2yS3^yFARz%4sqYewjc%VzF*8 zb~LeWqn%J)CekdTKD!H4UY>Ecz<#%bHW3X@T{D%Vs*$5zq1g#&o8 z0La0HVzbiwt)aKNLHQa|Iwpl_)dpzTND|OH_?2J5DPRfLYhF`9xGy+%0AoPQH-PAd zBPW8Zk}CV&MvPud9K6ne=uq;ti25wN&WQTT zHeC`z?qpdw_=;6C9k*G%U&_r7$nEtD* z%ywkG;Cqb7kk(*=LplbjE}4;ha8-mbw&~Ytr1|iyBE{1S#u~80pcf;p66GtsMimu) zkLe%6QEPD}*`rE!a$X0wlGce&EegO&M4@;HWGRz={Tl0$?S347yrA%|kX`%#)fBAB zwIvcp&2M%X)RFC39DbajNWdm`w~nA1VsVi-iXYtMMg+Npi|`szpJdi=>17Ptxc5$v(q9W zt!hJKummG^V-86Yi4bG1IE4%)wYIhW1MYl(tEF8noNIm3>3M93tz zkV->}7F)SNh*vm;@uwCw#tg4$guL%FY$UZ>UiWEv#n9P5=7ZyZn}5GETGps1AFVol z=Ntho8`KId7aoH{ZDcz~1T>YtJTe%(E4}-&{1HWiR7k?yzy+Y+Oj! zVC22=&}>S`+W_XH@z`u`y#&C1T3Y=o)96=oL>TYg_rdtsxL^3 zW6|H#YBZa z!nsOI=f134zqXjb*WwGht?Q!c-3$7yC766bDCI%;70vILOe=_BYw1NBj8>p8&MAH$4{x{{QR$Q>eGuzAWI`JO9M@pWs&r;!V7E*T9~B zEzfJ)dq+T-Z#Mti5D?0agw3IonWwl~q&)23DGGi5pW=>9`d4H|%iP}&^!Etx|4Nu? z{@Z`ze+?zj*njG`{FKzm3k?r(@loNv55>YY#R)v#MAmJ~Kac=`nY5RWyqEeY>E0(X zTXF7A<~KhdVx7hC_gSsOmr0{|f?Ch5n`mgSE~41wy&ErXgSgSN9S;Gb+N1d^*N?Eh z2ao^9mjAPmXgbRDcNbw?>)?$W{}#&VI_{On-i4zNVH^_aoP?W9rEAs&4ZU$T`JJ-{^={47@63N>74Vv@4UdO=FQN<|oa32m zS=!c_iu9Y&+GbO`GUgHo7?A)PcMC>#XYi%P4+i6go-q0MVEOkYv9TXi zCjMtTK4~?My`Md}DbFp+VeWAjz{fLxz8CuXMk9w)Ohe~m`=LQZ`fu-19?yDS=e^fL z`jL%Qk>YUIcU)k8oOh1N!-3hbH{8kI87p#!Yc%sPo)82MU;~us7}}7c(Td1u(Nsv1 zgCsSKvzkiqK@Re;bTvc~&&#~e%VWL|e}0`}rmdvK!F#OWyT>1)h3(^3bJsvgIp zcu2YDa?cO_MGT}6)xN=3uH5{vCl#4AV6S@!A;aq5&yhIy2TQ{fErwAWS2D08Lr91e zr67YT5AWAk>j!7CJ1gnCHy6&kF!M2xTl?*o))`2}WX+28@Ar_k{NGqPT1zvE#?TuK zjD0bFeB`4Z6b6~AT$>#zVUx{-j%?SM$#xfSMwHX7o28D)W*hH|feo#gqop|Kvx#em zf7h)(IC!&bT%7ek)-AttPtJ0i0B~P*NM=T!Gc+dN%t7AVdr;0+ z?nR)j@}TJ0!)O#E$P(kO=E18E6ckCzO5bn?iI-QClJ`xaA$fSTJ?arrG;EAo)lZp9 z+|Y(w`h-auQ&yF~YZwMZGfCz?oE`<4=MgZ@1$GU!9J@j>py1O@ENx5D6F60#KV^CY zsu&7?osd#Gafji5mQrbSz?kD5HF0=YWbu1kD9vR%jdjz2M4k^W?>&Zwv>L@RZOFRJ zyKn5SJ|KI^$omNswrj@Dzn${EOa?`q^A66(i~MfzP;V>r0kzt8D0%Jv67}zS(wBPHE2FlpBkjIDX%qzDF!upC2bB5KJaIa{Fj~_h=BVeA!@Z3SLHP zx7Q&D^`^gu+fXXI)@Je~{QO}CQ!MjiGg28HDK;|FwQE~Hp*9wMm4nPXAd7kECCs#* zXv90vN`(_gp``O$)J{N=1>)VI4y-YY-aa1|8<`gLH|=AbPW{=SSPv1+u5IgSBf9r! 
zI?5_FEfq(m@pT2)g*8i9vq0Br6u~0Dp$5-Bg!1g-Soh)f2tr#O{dQUC@=((*y4VGr zV8eI&*6T`QZsZUhXdg9x**(Mn^$@>z>n4=ymc7BNO!Ug&Dg;A>=xMYOi#{8HWO9Lc zROU=vFrfjh5EE#nmdfc^!Optc8GT6r@9#vDndqQxbS!Hu?g-NCA>r{?wDoo2E+054 zX0!uh$u*Q~R+OAy^0Kcr=EE!=)6!yW(V$?EmhO%TA6SFxzdr$EZac>PY1Bgd0a-US zwq1a9+)$qj$R91FNIPKRj&Ks(+xGak&8XakT_9vrHzZH<^kViVY}BD4`U6!AT(DM( zKEn`FlxT>_SIDcC89q)l#h$Zr?oZBpx?fH~+d73de_1RE?0(^8b9j_$3ER5XipVjW zS|O#Ui4g!GjXLw|`Ouv*6mT@II~B~5exKxsDk(6etG1QFQ!HR%SVO$Iw9fq!7G0>G zooQ>Xg^FLb2NBb7=`J@UCE|ePkJ+A3;xD&{uxv7uaNKRC)@kk-PzTiB+Cfbj36fTo z&-*c?JE?BaV@ev@QMyVG42sIg2~Nt~EiBQ{pqKa}tM%HRh_jn$h;Z$5i{(7 zV_$>Ot7s{_|9rd>m3KYyLA21?_3dWl{tT`Z;7Ck6e}-oBvAQr!FHZeU^6TLXjlje|iAlRnEBsBi@Yc-dTBF>5` z12NglODs=_FmBT%8dM;`-9ewy_oy8z+~~GoW`mZ?%g_naiH0quf!3KqaKY0~;Yjj7 z$>^ZVcwymp@H**^#v^n9mpgblV(=hBaNPhudV0U;O(}_0Fb{d&Q!oxW2Ab7f5<6|8 z-FjnR(^v~(`MR`uSIo<%<)-b?h<(XOC^@1B2ZHPiuJ$Tz>C3v{ZL8gcOx*$dW8#NG zA#~iFWN9^NiGQY{9Rv95+!DHokI)i^y0d?8@|+b^9tN_Z;8p?0;ph5`|LX(;VsQ^{ z`^eZaOD{oEd{nKScgve)jOXLvgg0z+RjHe4&Z-Wkf0lqJPIe=c_Y!t#;{*R)_UmJ- zT{lbNwrAuC?(}Fj-E_gr9GFPW2rGZcp?L1TT>4ZilbgDZdjn^Np<$CFZx4%@$jZ_# z{NT-|&u`|GlGR}9v{&l|;yJ6*j6c*@S9YOQy?@?3HSO>ICcH0tqM6L8a!mq9VNM>f1tRx?oo99z1v%A@w zr${8tE%rDr!woFM?4`|S4}o#ya>CpB2qs6Te%&)I!9?G8p8YKRL%`ruB)`T8dyS&k z@RD&)=W`t=(py2u`e6_x&m}W4eqKZCY8-0pyCbP!y9n#1mA(FLxt+wTO&Bl>U|;Wl zceCj*_E~b%oaAF4$^LB?KUR~XD15VTne*R_+?5L4YCZf4qYl?y>&~la!XDXga$r`M zvb#m%G~>T3w=v*j#&isD+6wpG*1J-uYLsu_-6owl@=$?wYkAjs{m?zB3Mgg=9NGR> zkMCDIHctg+gRj{v)1n*GRjve_O|Pv^s-5()Q3{IAq}$~*y=of$41G)3^s&tGaoWuU zGL-(tj{!p2c3G|smBK*xJg?jfyo}FX$^E>Y%!p$xKP`U2bN1gi#QZe5^|Z*GM$)by zKDEI8d>Y$gC}6)WwBjcw@t%<$`Q#Dp!E(ZTo=7HkZKk9tuyCK)*5zM+H@|kXX9H{y z^4F4I4R7-73(emNY=?=OyeJ7CN{Q)hvk~Lcq|ZvjJYPko{=K4%>(|fm=z571EIYNN z7_{h#`$Yi=8H(ayG03zDF{PJT0u{&!+BE2Brwf+M2#MFM5uHF@6 z>(x#=zsh=^`RBm-H_W_;1?PS-UC1us(^lIVxX*X@nbie4ISZSp#Ut$ftBKts1eM{R zk25I6{ABU|2 zYthlas5|dtk>g`=$R-9MniVT<+|%Az{4x!lY} zarNRn1Fm|%*EE~>ijrply?tg4XnGp$L2TofLYCWq2erj5cDcM)N%raIpD%s!X%eV+#PUYFgK22V67{I=uUqCa6Lfpokc+P@{Ncx!IIH*eJ z(v(8&hD&}7sw-a~pLl8I38dAYxwK)#eGCobH2?ixay1LAUv3WX8-t|+Rg3i0RC1)F zR&6bt_??nHLKvpK^p(XGTnxD#Iv$FK=a?Gu3O-6~c`8 zuAC;w5}wdPY}tV%7vct$eFiy@xbchod=BRf_VMA5q0qcDPk(hnuqhlb(s&^)1_l0j zK_PBuZ!l-ca`j`lSsr74%u@ty=n^?0g6}Pji=gtcPn_rxX2=vdKpP+v=4Tl$J~dB- z`|~^QPyZrb{g36!KszGbE9y;)r3S=bmUVs`N^wypM(DGtMF)sKHtw`>yF>_8#? 
z*F`VQHkXjIhsK3zujNLUugvf?YtC#ZQ87f#$8q#M!%|_4jx2=U4ahqjigxQ;r0wbK z-g(LwtQjRBbT$wgEfN!A7k@W=6Ub5Et@;?5h8agI65E?aJ4rt%;wvE@ji11T`hmon z8O#WG25UFrCH~`YK*2z$#J0*PmMXcH=)Kg1ib#zYCC-omYD!$WYEG=vdv7($ED3SY5$MDJw}8k8924(gQ%qu$C2 zib=3{6%e>cF^UV+!_67LwTc$gr__o*&|Q*`7tN_-qElazZ|yr`7APQak}ef(V;IYA z)B@?WLbOs41V|x@ln@FO>6V5%l=Uf#m6T;Eii2(HTX{y3`rM(SHGmRB{aYwMH)%fkD$o*pZdkQ72;?KHc$+$X#RQX1x zqw)AK9Bkf(iDT^F2k6o7n+7zuB;zTkhQi`~S_iLs0iJdK+yc-V&|MP3>V0a^nXsmQ zooe5Vm66C69iaWd%tiHN{!oagw8f=k)Mk7uX@QUp|KSu{1svyq zlWN|m2)L$Ef(biBZypaj#(YAL$lPF}CIM>q#X$+`R4F~sYN@{!q0^d4W- zpMYEAMF6tD-I2%Cl)v0}xL&JaWR2XrefglK{GOANjU zhBKH@?tr>dQiJgk0ADUbQn{&8%%+k66qBTk+WO0A}0=U%k zrJ-=KjD>xa4#3gAVMEi*V1-TR6!k*rE$T6nKsUfZL%^t3F8am8-Qnfo5Ey)x+n^yq zaMYpH5{JVn8)IG$J0Y?sTI%qt$c*#=UT zgA&0Kc(!MOEg_ie>C1=gtmX+kR#7~PD~(~jU1+M7Pm z$Q(QODv^`Np)jyc9oLZ-hW~gSO3$Z|L?iqH7RuesnI}!6=HE)Z{)|}3$`w3h4(3LC ztw~~9zRnUQ-D_cpESI2A(X5*e1{>lBFbo)p0pNW&12V47AS2S}4UoaMsn7~6v8#_5SkBkqHPR1{a&7Gx?23Ms6=kA|pj3NDAByZ+ z7VoCGY-W^~TUkQ2BmVwqys`4j1fx2WC}qElN|chNf?}?ScSBxO?Ll;Xic?UslKrNn z!?XmPvbiy6DHyO0Dtxe5zX~^92rekX|;1c>NShEXoVDIO`)56Re@IY4XDj_ilLT!uSt^dPUAHiS$N1&co zZw77BhD@o8ap!zpCj1|d(>NTPB*CNkf^H_#Lc%oG$8|a*&%!p)r5(;o?wcN5Uz9aH zycx!UmK5zA+CP%;EzR=Nm9mpM^*Qj6O0keihbRNJE%4(gG-+Bem5S;yb;0;bpG4lG zA6BB7iAs>dKR^LtKor|U47jO&OVyQj8DNR%EI79gTr74#VTchi_=41bx&IU`FtzIc zQrhFVGwU=KzrfV++=^l8*oKst%HBwEWSGxHuIItt+0`I^XeA8NqW|#-7_j(b=>oU% z4a{v(1H&tUJoyFF?F`~xU03EBm zwOuC2z?3fmlHFBFK9;X3X&?r6Ae>XU7O4mAYyHI> zjI5Cg$o_*PiZLP@Vlw|R1eZA+a;b8bbey1amN7bB2ys}j+2(KU2wt-|^3nYZIN~ns zSUyO?rD7x^8Zd|{3AGTvNHmq1aDsl!4P*+)wjk~6gsL!p4I4WQu0nh>Iy0B(qa=wO zZE^_~vykYc6j%X{tO)!*b4X%3M0;9(Q=71g9u2?K7pGwfUo-z|X&m-ZdfA?%)o2&f~%JPFQ7tFN1h)D_hfmD_kKiSAXy}~hq@4Enai`5#683rr37t89b*Md_L%y%KXO6?Nu|P{WASXR zZ%gvr<|?@v!MfZBaWX^VNF>s21}Z&i!Gn|%XYr2tMLsm})UxAIiH?Ou?=hPThaT@$pc#EsgWP7|)(53myID{cz zUDvWIZ{A$6Hf)3FGVW@y$s*uECv&2iUBkv|qDd@j&(n9BjPzO4XaPlt7;hr!wyBG@ z7l=u~Fn5Up+;+Lux~LPHCt;W)Vfg=PzyNNHmGX*VhfrX~fp!()7tdf?{aP!D9G}qg zq4{DMLV@{dRv6e;7^zkmV3Oh{86P~zP}>zJ!EJw|x!F^x&5~-9B7KEe-RA1A?3p=C zAm}<(z7F(n^DkNAb!{ehhy4mQ*9LL_It=2ILS3sP%-8h0k~foj`#<0|VgEPMsGer0 zg9>^B!Cv`4e@SnDZ_W$t=5Vd|H7M-L|9=FZVIyO3?thpDAH0TQoO@uDH?I>c0^k1= zYdpw|_|7E&OwUKR52m^@bN>fjl&8l7{BK`M|Gx-p3ZFt$Lea4FyfKA7NdqDk8|5tb zV%CSJpzM@XJ8=wxm6H5?GHK>FsRilzlV1>Z^yuc9p8znrcjxFs^#88#%JH2ru6FRomA_{y z9&ut6mGBD0(-S+oyz|UUWEi%0;=xZCTWX?{E2ficLKh&GnQ6kzy?frO>}0*Pw6;F{ zhiij**p2L(`Jm}2uQ;+0uM66J?1fg{BiM(Rhv)`gJu>O#Sf*@0hm8pR75nC|AN0Fg z==Hg~eE9QF#}E%4ugC63B!m4?7epAf>yLRu|Ds(c2&^Q3pL-~OpxINbPyT7SCTw1ia~mRfCICBbZ*LUoUPW+Wt5>qPoF z6E$L!PPjj_N(`C?#`c9&7qsJ9POAlRmgw?NLFH@XitXWGViP(^Ugdd^eT+%J`^k`n8vY!@ za{8nN9&>}Ewk|1b_A(5sEstG%WMJYCr4E-$#}m7RvdZ8})EA8cq?K*>%?2SZE#X#6 zc}_Ma@soCu{QA;)pLmqc9PuDv#I_kZqh}PhZ_t(xB2ll1L(c&pFppsE*7<5s_CfZZ zw+g>EZDOVop_^h~Em)NSg^dAf+d?p44p!@fpNN|{Q5L|aN{XB;m2+{@ur0he!4&oE z&MtPm6Oj@}*Kumsnaw$S1Ytv-W|W={mw{9geId4^MxGF_syEq0kUSWXF1=tQNzXyW z>P9V!LPdo_ER;^eLjX=*Vnvf+miUr~H~-Aw*PBfP-!)fv!6SXbV{3z=VvM3{jG`iq zqALAA%J0RVn$&pBse;UJXRpC*OGdo%IB)fCZ z3%v-^k;8nB9uT(WY2NKiXfLC#SKxgZ-mdslHGGble)6#1Ce(u1e)$d#smGX6_4st; zhSJ4bCgaSf(-aLG4eR^sm!X!~ ziagbtG?zJCBTwk+SjblFC+fJ8`kr|9pg6w36$%a}-b79VhALh=bS9?anz3;V*(%`- z0efb7`tw}QGsI=cc-%`i4qG}-;SfF*uWMTMF_IDQEV(6LsNE>Zuz9*MnBa6 zDe6Ek#ii+2szb~W)QjkNX|^Nn(O|BhFE*}B?-X#poX zhiBQvRZ-`B(fock=K=c!HQ%~XQ0+#t%f3(i_i~XlTODWo$Z+w%-t!*F?%Wu z+Zi^FS+?0Go5*jR@g(}iAv)A#B4*OzShjx{lrW6WbpG(kmOUxHa*da|3-qt2YE%!L z5N4Wb>XkKymLtYC?3SU?5{e}LxOawqZ>Kk&Lnk?j^mhs7>nI>3+&=h`{Co6%mBZUD zAh#)zypNwU5r-_`!{YF<|DPIX+!JZsFgd@=enI)-@`stzSYLPV!!Uiq}b 
zi0qRYHe^=qc!fauDSg$;M5pE8&TORj%lSzYCPk@wq&LUo-z-jygm{k@(Qzl>Cg&ixv%GM#L7~%J& zAn%OB$oeCvfbfK)kn@LO7_RaQ^dV`gc5a)hz~3VK3Np)C?xCjWmpJ3U^O(gDP+Wu62V};}JmaNu%uB$H9?IC4(KAk5mYL+C+MO4fbuCd+JBP z#C1~4egCh=jA(>^NnUf9ie43tqvHAfv|a-dM}iL<)bA*&!yi;L`Hib#KdH>U;ZBTN z{Nt)NAE}ZsUuInre{CK71aJ86in)<2XYd!g6qp$2&|lkZ`(iW;B)i6~@3ur~|Trvg~_aUCTQqK1~%P){>-ERkf>aVe}&Ql2swt zB6j~t=x)1&Sx>)d&%rO?HHYADxM$oaCu>lsfb|)>^zZyC4^vLO1>wKn@hu{) zbbtJ9&{i2@FJ;6z2pX3^vQS*{|0$6fvZoMcSJ6WMf=!du8 zYVOff$tvZ+IrGVRy6TY3=%Y`mz}1xu)PCSZ5T#J4$z%{Ef2hfq=LZ$g!Q2(#k zk%sW-IS}|OCVv{Zag3z65Xq|+9gq6U6v=?H#^{#EmZi8)5d5!+Fl{7GiW>BU(Z#Hw zAv!AA7wK%Wf;0-L+&d+%e8oliiLl!nh8F4U0uUf4Hmd)-Q5 zLU#6m;It(L-YlMBEGp|32WCnw?nG7qTxal`m(m>MA(|EUo(B~{i8zc;}ms?FUSCwwvu3XBWbpAO$g zC1^W9t(TW7RW2c}{O}z~^n*(B`*E|}hoU=K*SlZXJ0wo-A=sLc2>U@xcRK==w z=qEtff2rtrC^&f@d*e%$wT%y?aqmk#)S=m1x!(AWA1|h_)}_y9*@<|c;4x3?^TyC{RoPvr1U2(#SO z(jbEO5BXc%r{!4TzE(3rG9Yc#P0jYbi)=RxSATJhjZ*7~rKa)DtZb>>w8ChaSfjQQ zM#o=brsG}gV`rxW%(`1(E0$_2_EC!$M@alF(uJ+WW&`Rm{hQkwy>;L1c2Dohe61dw z4~+p|wv6Ua+{2Jnw#BscpQ+uyhx2Ow)fDTgm@6+3@|OTVKm6li3mC{}x5-%EHL6XGxRgQX7iou-1DtGWN=tgkFSuS7L}Iup2tTi&yu360a7`IKM%LeFphlH60 zRT{SDHCs!b0No6LU+<;r)_`&FL|afq`P#`GNAD;7)?CO&La}XvxQ{>xkAm9mM2bVn zY*Pc#6*|+fCTog|b+B>TV##d$!rTELZLZ9$0@1tz)d^LMWuTY|VjQvQQ>-=|_Xq6v zS8270ysOR{WoN_^*%vht=$Zd3Mi9C0#T?X;jp3453)HQ()2K^yRLaRyIhX{P)(iTujcJx#o@sTnZAE|e;q zc#sQuls|Cb(>ab!e2uB|&4q6ki=HlztSwq17rsdjqVyO_!jv2pp;Ne-%D8pWSsC}$ z3D@fK^iAH z0`{AA+qVm_mw#!|&189#dQ;S~`E^lm9#d~(W>Q4|?kr8frrN|rp30CfD$Pb}O;c$u zRz4E7X^w&KO|_b&md)w@G1F9|^Dp)%>glRZhU-gHoOc2+l#N}Sn4-coH6|n8(zuEP zbC3%27^A_NcGx0oq>i*j05pEMC~8f4ocX$`^$s;qQ`cd!yp6aa=z%H5IE_`HeQ)FJ zhoh>e$1Gvj7Gy%i+0NCzR#{O0cx+?eCg)Xq9A&NhlX#zyDeK|#HvGWVKHjh{-mq?* zU?_jS*aS7Z{5$b$P(7KegkPm{*AuaVd}P81msKN~K7{4MSPL;)H11z|5@>1|HH4Pt z3>YR{DBogX=~!Z_*g&~#pa5+b8gCDSTTZqdkXbV-yeAQVUHGADQ#FNHp=8Y2qW*#47;!(ExZzP*vOelyOlzopQH4DNoCr^0v(WhWnfBZ@Rz9 z{-*nz>~Ffi$^NGMo9s`q@(*4|G?o`--4E8=`do0UB@2b8OEM-4g{R+}WTEhMq44CI z`IsO3k)icu5pY=l`KS*WwL54Q4%>Hw=G`!2cW_!7Pm^n^ZPp&sBxV1eX5DzGG#Yi{ zQMCA<)s~vh-b6TdC^?6~F{iXXms_fzZSPbrS zF*wp!FVj~qR6RmtnRM;#eV>%m#X+Dmm>zn;ONB-<=b;gvaY?5>9^E@aTh(};`3v<& z6!=!IDqf=0JH%x@^~NNSn`xGl#tM2k zy1xICev4{wx`u{TM+g_Z;wJP?ywIulIRP3v6xHl=28L8;W@oGcV^|B$+G69H45K2LX}Z-FmGJthL7LbgZ>q5>P)r5;x9yi)2X^ z$5|tfsaZ_PVY+n3fi-W!nm42iHu)ABr#|5=8}J@2i~#~mj7CNO29P$wkl+>S!dZ1h zA-lXY9SOT#QNa$or3(qw!HU|&iP|?X>ih9cl5kS&jdd%B?j-|bLtID$DU3#yUX>4s z1mP%9tnO~LG6ScrT&97%k)cec0Yq?Zi2^{{iz9I#V(p(}`$(Z|VH9kzmZhvM(Bo16 zUWiODMW$EQFu7Yr>}#4>%DdTcrEf=e;Ka=RVlK|1W7wpEjr3&Pd_ zm4E5}WYss8l)f5ywxK3B`*q<4SNWOfRH+p+p)0}>4@_MlR>uUlRBGKPi+`$@g!d`s z$xb-3Q-18*Vs2aNWn;Do%@zUOHq>X}NY7l0E9GiWPRg&$Dv8A$FBY+~i|IDHFSYzc}blzMesxYtAj6bVneC?1R>0_2W*VrHfti(45l>)7Cw<*tW*+e3h^`rS`;pX!uag9rCiA6qvr5ETT7Zk zv*vJUaad+uG4b`4V+q$I??|Z>jhclm6coN$-GyB`HJfhDrc<+V{C45htwx?M&~R1r za`mXG*diLdAH#)dkobhgg=zSGKJN+6r(EDFv@%$DW!O{J=_yluoAKQ(e|Wh`8oed2 z@CYF|ma$Hv(R>EYDaFUPsaR}l7MoRg$Xj_vc4$ zDtz($$SfS8cqQwHR}+?rub<9VS@bh6$7cc9X?lMv@H!ZOb)g6@3-Todxevlz2Nf7S zjE}xT4U_TD?f52PMf(+2Sd3S`_~T~$QSnCN1?~ppchiLk zfd3BtpGm*^6d%zYb_w-OsICh11>=1oImtVkmJj29VgwzCkKcxxcGZ7QmVtg;D6Hc> zULkKGuDd-}@P4aY`;Wf=Px=*56^`+%id|_=$BZt&0_c&C`H?}8EM{FST`QioK+lU- z+YT=SXYoXJwNxS11pRf2|N6c_13rf#pF_gukn%a83OxNj2Q@j)XCDA;uI7X;Z}t^8wk9#y*FSK8Mdf2l^n=J_E40eCD#Gtd1;n)^>Bt;%`O=tF3h2W9&#K!X3NWNYoPzv2D{`y1|W zvcKv6Ci`f|vCRI2bR#kZjpa#s+UJn;IarfnKsn8dlT4USv*ILKanh_fNmiV+ zrj=xs$)A#po0`ar>?*M9l8g!unc0+Nk=<;Ooph1ie37i@jmAC&{=Pk5Ab;Op+*n~_ z{~l&kA1)b(TF_{`(g#<1muTZkXXNYpXp?^~srGRt{=D)5c8SI^y{yWRTuhN%pI2&v zXV+&R0eo_3Ec5qXsT;HCn@g%SWRfkH)LJa5RgE=+q-G#Z0>~LOUdg%A$MZ=n%GM0% 
zgQ*x7=cGT-%S=2sZ9|xbvOO^^i81|}A_4nSIT*m(>&|v$c6Wb=x)@zmF|KU~M8Atv zmk+(RVxSKo$z2RKU=L}Q0>hUwuq&8Dl;yTo)~VHXYh_(p*YUlgR%HJX%fT(7NNb1X z@0>NXnS!}Jws~!ApKm|q?SJ?0zv-o~@B(fAT1s>$ky_)-8YI>zTLuQUa%E~?xGt>f z#^!=rRG8XY-L%%#N~v`oC?Q&CK-`EI0Skuw&#LJgmlKAyu>96Hjy-a|s2xJmkwM5|2YJ4%q|Ad{ITeCIgw(IGY%8`00G& zHCQ;&`rbmM7mc4}@X6{q#bK!UUr-@`_#4+=07Z#&FYLPh5B`mwF&;3V8OoWsatYF% z==Z3>EsO}Nzk7D(WZdkVLDtoj?Ym(YZ+XfBHghNf3wv7ejs$osRaJt4xtxxV*L0Y& zgA=}>RK1)8fl>-%o1j0s+aW0-afOHZm$LqDJ0bB}U<#%z!2}RGg@e8=dRgiP|7hzU zfvGl?rI3db_1^wcv3T@9`Re72}R z5`TYwtL_eC3--;E#HywaVt_?9hFVi%S@CdI)q z#KmC-4>=9KS10LP;wllNLG3Z*@U6!dqB=lNL%Ms6Wq#W=kiMNE29Jj$(b?mjJy?Os z*cfnmG;w;kE|Kal-*)*e7UBZVOyw7X*HCd(aKMH06_Ktej}g=yan~Wc;4%53Fhg^d zC#0G;AtM_MliI3Q5hNixkU!LwKh#mip{tGq;R#*bo8w#H&)BYE!y-1yFgq+obCkJ- z>@4wapnkPc`U6g*Aq2y)lM9 z4=#PxJW0H2skJ6iek4_YBvF9mZkHtRd5ZNEgm|f(h{T(fs=<|D!Aa$#-1cBb*361I zUqyFsyD&fNcUS@T(s#HMTX$lE*>r10-mnI@V-%RX#ZHpQSIXn>pvx$C{Co^yq-S!> z-Ayz36!uaxIgIF~_=#)ziQNea*L;2lYucUVr7taLOp3eg8oNix_X7^_BhEC|LoMUe z%dKklg#*z1-sS7s#}LtHwdj? zZ-j$3*CS^v=1!Leaj|~+x1RyPG^*`->a2gh9e_oaFB&h3s(+nc6gwBiZkcMM1H32) zRE>-==Kn3%hhB$^>al@AnvW$)w%Q@U&;f6{4XY8r7O|iGxxO>qkf!3VkPYKF=egm6 zZaBHYHr$NE+6X26xpbx0T%iL6+0=y)pYu6*!4k@FY!+ z@PO;&l0C2J%VCVT+6RD7LD|0_aJycQcT87_l(E}`yDxoW*gVM#+A)lyf5Hm3H}0=4 zDd*ew0PstYco*{!|7?F@)!GYvk`GY2R7i6@gKx^s-gduZNK|)Q zU$aVvx_Ws{3&hjLafy)pG(m*l_|z1Vb~qLPqb3a}EM#nzB=P@BO&ob%36rDJB%kM| zqkDKhgYPslt30efY5b?BS*<)G&O0X1s2l`NN{{gF2*P~2fig$MR*uQ2Yme9S+_`7+ zS)5lS^9cBS?5yj85Z=~_7+l=Yx*!Y7z~0srcxUVEVW-%km#{n{vVUP7}@RP000D$#VG&) From 7aae350fab56338f0bfc98620cbc71b694f05ad2 Mon Sep 17 00:00:00 2001 From: Beal Wang Date: Thu, 7 Sep 2023 13:37:07 +0800 Subject: [PATCH 017/122] [OPTIMIZER] async launch dots for hopper warp-specialized kernel (#2251) --- .../TritonNvidiaGPU/Transforms/WSMutex.cpp | 43 +++++++- .../TritonNvidiaGPU/Transforms/WSPipeline.cpp | 97 +++++++++++++++++-- test/TritonGPU/wspipeline.mlir | 3 +- 3 files changed, 134 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp index a7bab9ff1366..4ed9f0c64996 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp @@ -18,7 +18,9 @@ namespace ttng = triton::nvidia_gpu; namespace { // Target operations: dot, load, store. Add more when necessary. 
-#define KEY_TYPES triton::DotOp, ttg::InsertSliceOp, triton::StoreOp +#define KEY_TYPES \ + triton::DotOp, triton::nvidia_gpu::DotAsyncOp, ttg::InsertSliceOp, \ + triton::StoreOp template void getKeyTypeId(Operation *op, int &id, bool &found) { @@ -209,6 +211,45 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp, unlockLocs[i] = op; } + // Update unlockLocs + // ====================== IR after async launch dots ====================== + // * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 = + // %3) { + // * triton_nvidia_gpu.producer_wait arg2 + // * %5 = triton_nvidia_gpu.dot_async %4, %5 + // * triton_nvidia_gpu.dot_wait {pendings = 1} + // * %6 = arith.cmpi sgt, arg0, %c0 + // * scf.if %6 { + // * %7 = arith.subi arg2, c1 + // * triton_nvidia_gpu.consumer_release %7 + // * } + // * %8 = arith.addi arg2, c1 + // * scf.yield %5, %8 + // * } + // * triton_nvidia_gpu.dot_wait {pendings = 0} + // * %9 = arith.subi %0#1, c1 + // * triton_nvidia_gpu.consumer_release %9 + // * ======================================================================= + // after async launch dots, there will be outstanding consumerReleaseOp after + // ForOp. we should expend the unlockLocs from ForOp to the outstanding + // consumerReleaseOp. + for (int i = 0; i < numRoles; ++i) { + Operation *unlockOp = unlockLocs[i]; + auto filter = [&](Operation *op) { + return op->getBlock() == unlockOp->getBlock(); + }; + if (isa(unlockOp)) { + SetVector slices; + mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter}); + auto iter = llvm::find_if(slices, [](Operation *op) { + return isa(op); + }); + if (iter != slices.end()) { + unlockLocs[i] = *iter; + } + } + } + // Only cases where all lock/unlock locations are in same level make sense. for (int i = 1; i < numRoles; ++i) { if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() || diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp index 459eff719fb7..373eac0e548b 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -341,7 +341,9 @@ DenseMap createForOpsForEachAgentId(scf::ForOp forOp) { for (unsigned i = 0; i < usedArgs.size(); ++i) { auto oldResult = forOp.getResult(usedArgs[i]); auto newResult = newForOp.getResult(i); - oldResult.replaceAllUsesWith(newResult); + oldResult.replaceUsesWithIf(newResult, [&](OpOperand &operand) -> bool { + return hasAgentId(operand.getOwner(), agentId); + }); } agentsToForOp[agentId] = newForOp; @@ -642,6 +644,30 @@ void buildAsyncComm(const DenseMap> &map, agentsPC.insert(agentsPC.end(), agentP.begin(), agentP.end()); agentsPC.insert(agentsPC.end(), agentC.begin(), agentC.end()); }; + + // Don't pipeline dots that depend on ops other than scf.yield and scf.for. + // Because the DotOp will be replaced by a DotAsyncOp, which will be issued in + // iter_i but waited in iter_i+1. The use of DotAsyncOp should not be ops + // other than scf.for and scf.yield because the result of DotAsyncOp is not + // ready in iter_i. 
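// (Sketch, not part of the patch: the shape that getValidDot below accepts is,
//  roughly, a Hopper MMA dot whose accumulator is a loop-carried block
//  argument with a single use and whose result feeds only scf.yield:
//
//    scf.for ... iter_args(%acc = %init) {
//      %d = tt.dot %a, %b, %acc    // rewritten to dot_async by this pass
//      scf.yield %d, ...           // sole user of %d
//    }
//
//  Any other in-loop user would need %d before the matching dot_wait retires it.)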
+ auto getValidDot = [&](const SmallVector &block) -> Operation * { + Operation *headConsumer = block.front()->dstOp; + if (block.size() == 2 && + isa(*headConsumer->getUsers().begin()) && + headConsumer->getParentOfType()) { + auto dotOp = cast(*headConsumer->getUsers().begin()); + auto dot = dotOp.getResult(); + auto resTy = dot.getType().dyn_cast(); + auto cArg = dotOp.getOperand(2).dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) + if (resEnc.isHopper() && dot.hasOneUse() && + isa(*dot.getUsers().begin()) && cArg && + cArg.hasOneUse()) + return dotOp.getOperation(); + } + return nullptr; + }; + // TODO: try to optimize locations of arriving and waiting token // for fused-attention for (auto kv : map) { @@ -694,12 +720,69 @@ void buildAsyncComm(const DenseMap> &map, builder.createWithAgentIds(headConsumer->getLoc(), token, pipelineIdx); - // insert ConsumerReleaseOp - auto consumerReleasePoint = - consumerReleaseHeutistic(tailProducer, tailConsumer); - builder.setInsertionPointAfter(consumerReleasePoint); - builder.createWithAgentIds( - consumerReleasePoint->getLoc(), token, pipelineIdx); + /// async launch dots + if (auto cvg = getValidDot(kv.second)) { + auto dotOp = cast(cvg); + auto dot = dotOp.getResult(); + auto loc = dot.getLoc(); + auto forOp = cvg->getParentOfType(); + + auto agentIds = collectAgentIds(dotOp); + OpBuilderWithAgentIds builder(dotOp.getContext()); + builder.setAgentIdsFromArray(agentIds); + builder.setInsertionPoint(dotOp); + + // 0. replace Dot with DotAsync + auto dotAsync = + builder.createWithAgentIds( + loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getAllowTF32()); + dot.replaceAllUsesWith(dotAsync.getResult()); + builder.createWithAgentIds(loc, 1); + + // 1. insert ConsumerReleaseOp for DotAsyncOps + Value cond = builder.createWithAgentIds( + loc, arith::CmpIPredicate::sgt, forOp.getInductionVar(), + forOp.getLowerBound()); + auto ifOp = + builder.createWithAgentIds(loc, ArrayRef{}, cond, + /*hasElse*/ false); + builder.setInsertionPointToStart(ifOp.thenBlock()); + Value one = builder.createWithAgentIds( + headConsumer->getLoc(), 1, 32); + auto oriIdx = forOp.getBody()->getArguments().back(); + Value consumerReleaseIdx = + builder.createWithAgentIds(loc, oriIdx, one); + consumerReleaseIdx = builder.createWithAgentIds( + loc, consumerReleaseIdx, numStagesVal); + builder.createWithAgentIds(loc, token, + consumerReleaseIdx); + setAgentIds(ifOp.thenYield().getOperation(), agentIds); + + // 2. If there's any outstanding DotAsyncOps, we need to wait for them. + builder.setInsertionPointAfter(forOp); + builder.createWithAgentIds(forOp.getLoc(), + 0); + + // 3. 
insert ConsumerReleaseOp for outstanding DotAsyncOps + Value one_ = builder.createWithAgentIds( + headConsumer->getLoc(), 1, 32); + consumerReleaseIdx = forOp.getResults().back(); + consumerReleaseIdx = builder.createWithAgentIds( + loc, consumerReleaseIdx, one_); + consumerReleaseIdx = builder.createWithAgentIds( + loc, consumerReleaseIdx, numStagesVal); + builder.createWithAgentIds(loc, token, + consumerReleaseIdx); + dotOp->erase(); + } else { + // insert ConsumerReleaseOp + auto consumerReleasePoint = + consumerReleaseHeutistic(tailProducer, tailConsumer); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAgentIds( + consumerReleasePoint->getLoc(), token, pipelineIdx); + } /*****************Buffer related*****************/ /// splitLoadsInForLoop diff --git a/test/TritonGPU/wspipeline.mlir b/test/TritonGPU/wspipeline.mlir index c2b0a1b70813..a08e6fe1ef7c 100644 --- a/test/TritonGPU/wspipeline.mlir +++ b/test/TritonGPU/wspipeline.mlir @@ -21,7 +21,8 @@ // CHECK: triton_nvidia_gpu.consumer_wait // CHECK: triton_gpu.extract_slice // CHECK: triton_gpu.extract_slice -// CHECK: tt.dot +// CHECK: triton_nvidia_gpu.dot_async +// CHECK: triton_nvidia_gpu.dot_wait // CHECK: triton_nvidia_gpu.consumer_release // CHECK: scf.yield // CHECK: async_agent = dense<1> : vector<1xi32> From 7d01c1852a86ff71abbe00dcead4af9d3e35e7b7 Mon Sep 17 00:00:00 2001 From: Izzy Putterman Date: Thu, 7 Sep 2023 10:48:12 -0700 Subject: [PATCH 018/122] Revert unintentional change (#2257) This change seems to have been unintentionally reverted in the hopper PR: https://github.com/openai/triton/commit/38d767ea93ae31c949b68baa50429f7a7e1b0cc1 Adding it back. --- python/triton/runtime/autotuner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 77091da42f25..e3f2794f7d46 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -115,9 +115,11 @@ def run(self, *args, **kwargs): full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} if config.pre_hook is not None: config.pre_hook(full_nargs) - return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, - num_ctas=config.num_ctas, - enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs) + ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs) + self.nargs = None + return ret def prune_configs(self, kwargs): pruned_configs = self.configs From 52aa663dcb5128ed63fffc6a92b914920c37bbbd Mon Sep 17 00:00:00 2001 From: Thomas Date: Thu, 7 Sep 2023 20:40:48 -0700 Subject: [PATCH 019/122] [BACKEND] Remove dead code (#2263) --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 158 ------------------ .../TritonGPUToLLVM/ElementwiseOpToLLVM.h | 6 - 2 files changed, 164 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index d6ff83cd5187..b1c73cd3230f 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1281,161 +1281,3 @@ void populateElementwiseOpToLLVMPatterns( // __nv_expf for higher-precision calculation patterns.add(typeConverter, benefit); } - -struct FPExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - 
using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::FPExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isF32() && srcTy.isF16()) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::FPExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - return { - FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0][0])}; - } -}; - -struct FPTruncOpConversion - : ElementwiseOpConversionBase { - using Base = - ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::FPTruncOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isF16() && srcTy.isF32()) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::FPTruncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - return { - FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0][0])}; - } -}; - -struct TruncOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::TruncOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(16) && srcTy.isInteger(32)) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::TruncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.u16.u32"); - auto res = builder.newOperand("=h"); - auto operand = builder.newOperand(operands[0][0], "r"); - cvt(res, operand); - return {builder.launch(rewriter, loc, i16_ty, false)}; - } -}; - -struct SExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::SExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(32) && srcTy.isInteger(16)) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::SExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.s32.s16"); - auto res = builder.newOperand("=r"); - auto operand = builder.newOperand(operands[0][0], "h"); - cvt(res, operand); - return {builder.launch(rewriter, loc, i32_ty, false)}; - } -}; - -struct ZExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::ZExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(32) && srcTy.isInteger(16)) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::ZExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.u32.u16"); - auto res = builder.newOperand("=r"); - auto operand = builder.newOperand(operands[0][0], 
"h"); - cvt(res, operand); - return {builder.launch(rewriter, loc, i32_ty, false)}; - } -}; - -bool isLegalElementwiseOp(Operation *op) { - if (isa(op)) { - return FPExtOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return FPTruncOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return TruncOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return SExtOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return ZExtOpConversion::isLegalOp(cast(op)); - } - return true; -} - -void populateElementwiseOpToPTXPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h index fbcbe95bd85b..22b2e2101206 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h @@ -13,10 +13,4 @@ void populateElementwiseOpToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, int computeCapability, PatternBenefit benefit); -bool isLegalElementwiseOp(Operation *op); - -void populateElementwiseOpToPTXPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); - #endif From 10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 8 Sep 2023 02:31:07 -0400 Subject: [PATCH 020/122] [RUNTIME] Get the correct end idx for regular arguments of GPU kernels (#2262) Previously, if there were any specializations of "1" or "constexpr" mixed with unspecialized arguments in arbitrary order, we might have encountered errors due to passing incorrect arguments. This was because the length of the signature did not indicate the maximum index of regular arguments. https://github.com/openai/triton/issues/2229 @shunting314 @amjames More specifically for cases like: ``` kernel( b: tl.tensor, a: tl.constexpr, c: tl.int = 1, d, e: tl.constexpr, ... ) ``` --- python/triton/compiler/make_launcher.py | 7 ++++--- python/triton/compiler/utils.py | 7 +++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index eb8079e84754..68dd8aeb19e0 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -63,8 +63,9 @@ def ty_to_cpp(ty): def generate_launcher(constants, signature, ids): - start_desc = len(signature) - signature = generate_cu_signature(constants, signature, ids) + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. 
+ signature, desc_start_idx = generate_cu_signature(constants, signature, ids) arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): @@ -99,7 +100,7 @@ def format_of(ty): # generate glue code folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] - params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] + params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)] src = f""" #include \"cuda.h\" #include diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index cb4f1f3ab832..d4b24a93ee49 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -26,12 +26,11 @@ def generate_cu_signature(constants, signature, ids): # CUtensorMap*s are always the last arguments + num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0 if ids["ids_of_tensormaps"] is not None: - signature = signature.copy() - num_signature = len(signature) for i, _ in enumerate(ids["ids_of_tensormaps"]): - signature[num_signature + i] = '*CUtensorMap' - return signature + signature[num_regular_signatures + i] = '*CUtensorMap' + return signature, num_regular_signatures def dummy_tensormaps_info(n=2): From 37478431439d964671ebb98f6b94e8efaa0f5636 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 9 Sep 2023 22:06:27 -0700 Subject: [PATCH 021/122] [OPTIMIZER] improvements to layout conversion removal (#2268) * Improved heuristics for RemoveLayoutConversion; * add LayoutConversionOp canonicalizer for ViewOps --- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 314 ++++++++++-------- .../Transforms/RemoveLayoutConversions.cpp | 13 +- 3 files changed, 186 insertions(+), 143 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index d91fa076479b..c69b62eb5a8a 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -28,7 +28,7 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", let results = (outs TT_Tensor:$result); - let hasCanonicalizeMethod = 1; + let hasCanonicalizer = 1; let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index ce7912a43a7a..58437f850ac7 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1510,158 +1510,194 @@ struct TritonGPUInferLayoutInterface // Canonicalizer //===----------------------------------------------------------------------===// -LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, - PatternRewriter &rewriter) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention - auto srcType = op.getOperand().getType().cast(); - auto dstType = op.getType().cast(); - if (dstType.getEncoding().isa() && - srcType.getEncoding().isa()) - return mlir::failure(); - // for hopper MMAv3 - if (!op.use_empty()) { - bool hasDotUser = false; - for (Operation *dot : op.getResult().getUsers()) - if (isa(dot)) - hasDotUser = true; - - if (hasDotUser) { - if (dstType.getEncoding().isa() && - srcType.getEncoding().isa()) - return mlir::failure(); - } - } +struct CanonicalizeConvertFromView + : public 
mlir::OpRewritePattern { - // convert to the same layout -- we can delete - if (op->getResultTypes() == op->getOperandTypes()) { - rewriter.replaceOp(op, op->getOperands()); - return mlir::success(); - } - Operation *arg = op->getOperand(0).getDefiningOp(); - // block argument - if (!arg) - return mlir::failure(); - // cvt(view) -> view - if (auto view = dyn_cast(arg)) { - rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), - view.getResult()); - return mlir::success(); - } - // cvt(cat) -> cat - if (auto cat = dyn_cast(arg)) { - auto encoding = - op->getResult(0).getType().cast().getEncoding(); - if (isExpensiveCat(cat, encoding)) - return mlir::failure(); - rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), - cat.getOperands()); - return mlir::success(); - } - // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) - auto alloc_tensor = dyn_cast(arg); - if (alloc_tensor) { - if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + CanonicalizeConvertFromView(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::ViewOp op, PatternRewriter &rewriter) const override { + Operation *arg = op->getOperand(0).getDefiningOp(); + if (!arg) return mlir::failure(); + // view(convert) -> view + if (auto convert = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getOperand()); + return mlir::success(); } - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType()); - return mlir::success(); + return mlir::failure(); } - // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) - auto insert_slice = dyn_cast(arg); - if (insert_slice) { - if (!triton::gpu::isSharedEncoding(op->getResult(0))) { +}; + +struct CanonicalizeConvertFromConvert + : public mlir::OpRewritePattern { + + CanonicalizeConvertFromConvert(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + mlir::PatternRewriter &rewriter) const override { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto srcType = op.getOperand().getType().cast(); + auto dstType = op.getType().cast(); + if (dstType.getEncoding().isa() && + srcType.getEncoding().isa()) return mlir::failure(); + // for hopper MMAv3 + if (!op.use_empty()) { + bool hasDotUser = false; + for (Operation *dot : op.getResult().getUsers()) + if (isa(dot)) + hasDotUser = true; + + if (hasDotUser) { + if (dstType.getEncoding().isa() && + srcType.getEncoding().isa()) + return mlir::failure(); + } } - auto newType = op->getResult(0).getType().cast(); - // Ensure that the new insert_slice op is placed in the same place as - // the old insert_slice op. Otherwise, the new insert_slice op may be - // placed after the async_wait op, which is not allowed. 
- OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(insert_slice); - auto newArg = rewriter.create( - op->getLoc(), newType, insert_slice.getDst()); - rewriter.replaceOpWithNewOp( - op, newType, insert_slice.getSrc(), newArg.getResult(), - insert_slice.getIndex(), insert_slice.getMask(), - insert_slice.getOther(), insert_slice.getCache(), - insert_slice.getEvict(), insert_slice.getIsVolatile(), - insert_slice.getAxis()); - return mlir::success(); - } - // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) - auto extract_slice = dyn_cast(arg); - if (extract_slice) { - if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + + // convert to the same layout -- we can delete + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return mlir::success(); + } + Operation *arg = op->getOperand(0).getDefiningOp(); + // block argument + if (!arg) return mlir::failure(); + // cvt(view) -> view + if (auto view = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), view.getResult()); + return mlir::success(); + } + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + auto encoding = + op->getResult(0).getType().cast().getEncoding(); + if (isExpensiveCat(cat, encoding)) + return mlir::failure(); + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return mlir::success(); + } + // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) + auto alloc_tensor = dyn_cast(arg); + if (alloc_tensor) { + if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType()); + return mlir::success(); + } + // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) + auto insert_slice = dyn_cast(arg); + if (insert_slice) { + if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + auto newType = op->getResult(0).getType().cast(); + // Ensure that the new insert_slice op is placed in the same place as + // the old insert_slice op. Otherwise, the new insert_slice op may be + // placed after the async_wait op, which is not allowed. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insert_slice); + auto newArg = rewriter.create( + op->getLoc(), newType, insert_slice.getDst()); + rewriter.replaceOpWithNewOp( + op, newType, insert_slice.getSrc(), newArg.getResult(), + insert_slice.getIndex(), insert_slice.getMask(), + insert_slice.getOther(), insert_slice.getCache(), + insert_slice.getEvict(), insert_slice.getIsVolatile(), + insert_slice.getAxis()); + return mlir::success(); + } + // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) + auto extract_slice = dyn_cast(arg); + if (extract_slice) { + if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + auto origType = + extract_slice.getSource().getType().cast(); + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), + op->getResult(0).getType().cast().getEncoding()); + auto origResType = op->getResult(0).getType().cast(); + auto resType = RankedTensorType::get( + origResType.getShape(), origResType.getElementType(), + extract_slice.getType().cast().getEncoding()); + // Ensure that the new extract_slice op is placed in the same place as + // the old extract_slice op. Otherwise, the new extract_slice op may be + // placed after the async_wait op, which is not allowed. 
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(extract_slice); + auto newArg = rewriter.create( + op->getLoc(), newType, extract_slice.getSource()); + rewriter.replaceOpWithNewOp( + op, resType, newArg.getResult(), extract_slice.getOffsets(), + extract_slice.getSizes(), extract_slice.getStrides(), + extract_slice.getStaticOffsets(), extract_slice.getStaticSizes(), + extract_slice.getStaticStrides()); + return mlir::success(); } - auto origType = - extract_slice.getSource().getType().cast(); - auto newType = RankedTensorType::get( - origType.getShape(), origType.getElementType(), - op->getResult(0).getType().cast().getEncoding()); - auto origResType = op->getResult(0).getType().cast(); - auto resType = RankedTensorType::get( - origResType.getShape(), origResType.getElementType(), - extract_slice.getType().cast().getEncoding()); - // Ensure that the new extract_slice op is placed in the same place as - // the old extract_slice op. Otherwise, the new extract_slice op may be - // placed after the async_wait op, which is not allowed. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(extract_slice); - auto newArg = rewriter.create( - op->getLoc(), newType, extract_slice.getSource()); - rewriter.replaceOpWithNewOp( - op, resType, newArg.getResult(), extract_slice.getOffsets(), - extract_slice.getSizes(), extract_slice.getStrides(), - extract_slice.getStaticOffsets(), extract_slice.getStaticSizes(), - extract_slice.getStaticStrides()); - return mlir::success(); - } - // cvt(cvt(x, type1), type2) -> cvt(x, type2) - if (llvm::isa(arg)) { - if (arg->getOperand(0).getDefiningOp() && - !triton::gpu::isSharedEncoding(arg->getOperand(0)) && - triton::gpu::isSharedEncoding(op.getOperand()) && - !triton::gpu::isSharedEncoding(op.getResult())) { - return mlir::failure(); + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (llvm::isa(arg)) { + if (arg->getOperand(0).getDefiningOp() && + !triton::gpu::isSharedEncoding(arg->getOperand(0)) && + triton::gpu::isSharedEncoding(op.getOperand()) && + !triton::gpu::isSharedEncoding(op.getResult())) { + return mlir::failure(); + } + if (triton::gpu::isSharedEncoding(op.getOperand()) && + triton::gpu::isSharedEncoding(op.getResult())) { + return mlir::failure(); + } + auto srcType = op.getOperand().getType().cast(); + auto srcShared = + srcType.getEncoding().dyn_cast(); + if (srcShared && srcShared.getVec() > 1) + return mlir::failure(); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), arg->getOperand(0)); + return mlir::success(); } - if (triton::gpu::isSharedEncoding(op.getOperand()) && - triton::gpu::isSharedEncoding(op.getResult())) { - return mlir::failure(); + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return mlir::success(); } - auto srcType = op.getOperand().getType().cast(); - auto srcShared = - srcType.getEncoding().dyn_cast(); - if (srcShared && srcShared.getVec() > 1) - return mlir::failure(); - rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), arg->getOperand(0)); - return mlir::success(); - } - // cvt(type1, splat(type2, x)) -> splat(type1, x) - if (auto splat = llvm::dyn_cast(arg)) { - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - splat.getSrc()); - return mlir::success(); - } - // cvt(type1, make_range(type2, x)) -> make_range(type1, x) - if (auto range = llvm::dyn_cast(arg)) { - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), 
range.getStart(), range.getEnd()); - return mlir::success(); - } - // cvt(type, constant) -> constant - if (auto cst = llvm::dyn_cast(arg)) - if (auto ret = cst.getValue().dyn_cast()) { - auto ty = op->getResultTypes().front().cast(); - auto newRet = SplatElementsAttr::get(ty, ret.getSplatValue()); - rewriter.replaceOpWithNewOp(op, newRet); + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); return mlir::success(); } - return mlir::failure(); + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = cst.getValue().dyn_cast()) { + auto ty = op->getResultTypes().front().cast(); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return mlir::success(); + } + return mlir::failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index c4cdaa1e1c4d..0b0dd9442990 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -241,7 +241,8 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { static bool isLayoutAnchor(Operation *op) { if (isa(op)) return isExpensiveLoadOrStore(op); - if (isa(op)) + if (isa(op)) return true; return false; } @@ -258,7 +259,7 @@ void LayoutPropagation::initAnchorLayout() { if (tensorType.getEncoding().isa() && !hasConvertToMMATransisitiveUse(op, tensorType.getEncoding())) continue; - layouts.insert({result, tensorType.getEncoding()}); + layouts.insert({result, LayoutInfo(tensorType.getEncoding())}); } } } @@ -355,14 +356,20 @@ void LayoutPropagation::propagateLayout() { void LayoutPropagation::resolveConflicts() { for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); LayoutInfo &info = it.second; if (info.encodings.size() <= 1) continue; // Hacky resolve, prefer block encoding. // TODO: add a proper heuristic. + int maxSizePerThread = 1; Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); for (Attribute e : info.encodings) { - if (e.isa()) { + if ((isLoadOrStore && e.isa()) || + (!isLoadOrStore && e.isa())) { encoding = e; break; } From f6828e1a6f3d768e6e2b2f6858c92bbdeac865a0 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 11 Sep 2023 16:16:54 +0200 Subject: [PATCH 022/122] [Backend] Make `ConvertTritonGPUToLLVMPass`'s `tmaMetadata` a member (#2271) .. instead of an option. This partially addresses https://github.com/openai/triton/issues/2265 to no longer crash when printing a pass pipeline in textual form. It is not a proper solution for the fact that pass results should be stored in the IR and not in a pointer argument. 
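In outline, the patch switches to the pattern sketched below; the class and type names are placeholders rather than the actual Triton classes, but the idea is the same: the result pointer travels through the pass constructor and lives as a plain member, so nothing unprintable has to be modelled as a pass option.

```
// Minimal sketch (placeholder names): a pass result handed over via the
// constructor instead of a textual Option<>.
struct MyLoweringPass : MyLoweringPassBase<MyLoweringPass> {
  MyLoweringPass(int32_t computeCapability, MetadataTy *metadata)
      : MyLoweringPassBase({computeCapability}), metadata(metadata) {}

  void runOnOperation() override {
    // ... lower the module and fill *metadata as a side result ...
  }

private:
  MetadataTy *metadata = nullptr; // plain member, not a printable option
};

std::unique_ptr<mlir::Pass> createMyLoweringPass(int32_t computeCapability,
                                                 MetadataTy *metadata) {
  return std::make_unique<MyLoweringPass>(computeCapability, metadata);
}
```

The caller that builds the pipeline then owns the metadata object and simply passes its address, which is what the LLVMIRTranslation.cpp hunk below does with tmaInfos.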
--- .../triton/Conversion/TritonGPUToLLVM/Passes.td | 3 --- .../TritonGPUToLLVM/TritonGPUToLLVMPass.h | 3 ++- .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 14 +++++++++++--- lib/Target/LLVMIR/LLVMIRTranslation.cpp | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/include/triton/Conversion/TritonGPUToLLVM/Passes.td index f94b8d30ae73..3d4e1d4e3c4a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.td +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -27,9 +27,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" Option<"computeCapability", "compute-capability", "int32_t", /*default*/"80", "device compute capability">, - Option<"tmaMetadata", "tma-metadata", - "mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr", - "tma metadata to the runtime">, Option<"target", "target", "enum Target", "mlir::triton::Target::Default", "compile for target compatible LLVM", "llvm::cl::values(" diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h index 3be5c9009014..54f26145b335 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h @@ -21,7 +21,8 @@ enum Target { NVVM, ROCDL, Default = NVVM }; std::unique_ptr> createConvertTritonGPUToLLVMPass(); std::unique_ptr> -createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options); +createConvertTritonGPUToLLVMPass(int32_t computeCapability, Target target, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata); } // namespace triton diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index b2f21f0c0a10..29aece938b5c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -387,6 +387,11 @@ struct ConvertTritonGPUToLLVM using ConvertTritonGPUToLLVMBase< ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase; + ConvertTritonGPUToLLVM(int32_t computeCapability, Target target, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata) + : ConvertTritonGPUToLLVMBase({computeCapability, target}), + tmaMetadata(tmaMetadata) {} + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -576,6 +581,7 @@ struct ConvertTritonGPUToLLVM DenseMap>, CacheKeyDenseMapInfo> indexCache; + mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr; void initSharedMemory(ModuleAllocation &allocation, TritonGPUToLLVMTypeConverter &typeConverter) { @@ -869,9 +875,11 @@ namespace triton { std::unique_ptr> createConvertTritonGPUToLLVMPass() { return std::make_unique(); } -std::unique_ptr> -createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options) { - return std::make_unique(options); +std::unique_ptr> createConvertTritonGPUToLLVMPass( + int32_t computeCapability, Target target, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata) { + return std::make_unique(computeCapability, target, + tmaMetadata); } } // namespace triton diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 16a7f69f992d..6d64fcbb1a29 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -353,7 +353,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, pm.addPass(mlir::createConvertSCFToCFPass()); 
pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass( - createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target})); + createConvertTritonGPUToLLVMPass(computeCapability, target, &tmaInfos)); pm.addPass(createConvertNVGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); From 28d4c3bdb4d100ac59c539035e98f00301f30ffd Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 11 Sep 2023 13:02:56 -0500 Subject: [PATCH 023/122] [BACKEND] Make sure `getAxisBlockStride` does not return 0 (#2273) This can happen when the CTA shape is larger than the tensor shape along the non-axis dim during scanOp lowering. --- lib/Analysis/Utility.cpp | 5 +++-- python/test/unit/language/test_core.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 3beb816d2057..f3acb8963c22 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -317,8 +317,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() { for (unsigned dim : order) { if (dim == getAxis()) return stride; - stride *= type.getShape()[dim] / - (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]); + stride *= ceil(type.getShape()[dim], sizePerThreads[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); } llvm_unreachable("Axis not found in order"); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9961dc51cff5..dbeca94c48c2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1750,7 +1750,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp ] -@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) def test_scan_layouts(M, N, src_layout, axis, device): From 5231d57c71ba2ddfea9bcd1419e2ab30a4650dcd Mon Sep 17 00:00:00 2001 From: jon-chuang <9093549+jon-chuang@users.noreply.github.com> Date: Tue, 12 Sep 2023 03:31:17 +0800 Subject: [PATCH 024/122] [TESTS] replace deprecated `torch.testing.assert_allclose` (#2250) Prior to this PR, matmul on sm_89 (RTX 4070) (`test/unit/operators/test_matmul.py::test_op`) would result in test failure due to too strict atol/rtol. To avoid having to choose strictness ourselves, and to have better defaults based on dtype, use the non-deprecated torch testing util. 
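As a minimal before/after sketch of the migration (the tensors below are placeholders, not values from the Triton tests):

```python
import torch

expected = torch.arange(16, dtype=torch.float16)
actual = expected.clone()

# Deprecated in PyTorch and slated for removal:
#   torch.testing.assert_allclose(actual, expected)

# Replacement: rtol/atol default to per-dtype values (fp16 here), and dtype /
# device mismatches are reported unless explicitly allowed, e.g. with
# check_dtype=False as several tests in this patch do.
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(actual.float(), expected, check_dtype=False)
```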
See: https://github.com/pytorch/pytorch/issues/61844 Replace: https://github.com/openai/triton/pull/2242 --- .../test/regression/test_functional_regressions.py | 2 +- .../test_persistent_warp_specialized_gemm.py | 6 +++--- python/test/unit/language/test_block_pointer.py | 2 +- python/test/unit/language/test_core.py | 6 +++--- python/test/unit/operators/test_blocksparse.py | 14 +++++++------- python/test/unit/operators/test_cross_entropy.py | 4 ++-- python/test/unit/operators/test_flash_attention.py | 8 ++++---- python/test/unit/operators/test_inductor.py | 4 ++-- python/test/unit/operators/test_matmul.py | 2 +- 9 files changed, 24 insertions(+), 24 deletions(-) diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py index 684cbfb4dd6f..b873db7a3fa0 100644 --- a/python/test/regression/test_functional_regressions.py +++ b/python/test/regression/test_functional_regressions.py @@ -227,4 +227,4 @@ def grid(META): b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, num_stages=num_stages) - torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index fd7c14e6c85a..d66c3b7952fc 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -149,7 +149,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) th_c = torch.matmul(a, b) - torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit @@ -300,7 +300,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K enable_warp_specialization=True) th_c = torch.matmul(a, b) - torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit @@ -456,7 +456,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N enable_warp_specialization=True) th_c = torch.matmul(a, b) - torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index 147249076181..3cc4bdced339 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -99,4 +99,4 @@ def test_block_ptr_matmul_no_scf(shape, num_warps): BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, num_warps=num_warps) golden = torch.matmul(a, b) - torch.testing.assert_allclose(c, golden) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index dbeca94c48c2..f962facb4696 100644 --- 
a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -881,7 +881,7 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) expect = f32_tensor.abs() actual_f8 = convert_float_to_float32(out_f8, in_dtype) - torch.testing.assert_allclose(actual_f8, expect) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) # ---------------- @@ -2594,7 +2594,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device))) # print((output - reference_out).nonzero()) - torch.testing.assert_allclose(output, reference_out) + torch.testing.assert_close(output, reference_out) # Testing masked loads with an intermate copy to shared memory run. @@ -2649,7 +2649,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, M=M, N=N, K=K) reference_out = torch.matmul(in1, in2) - torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) @pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 5f94cd8b31bf..7e6f820a374d 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -86,9 +86,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= da_tri = a_tri.grad db_tri = b_tri.grad # compare - torch.testing.assert_allclose(c_ref, c_tri) - torch.testing.assert_allclose(da_ref, da_tri) - torch.testing.assert_allclose(db_ref, db_tri) + torch.testing.assert_close(c_ref, c_tri) + torch.testing.assert_close(da_ref, da_tri) + torch.testing.assert_close(db_ref, db_tri) configs = [ @@ -138,8 +138,8 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): out_tri.backward(dout_tri) da_tri = a_tri.grad # compare - torch.testing.assert_allclose(out_tri, out_ref) - torch.testing.assert_allclose(da_tri, da_ref) + torch.testing.assert_close(out_tri, out_ref, equal_nan=True) + torch.testing.assert_close(da_tri, da_ref, equal_nan=True) @pytest.mark.parametrize("block", [16, 32, 64]) @@ -195,9 +195,9 @@ def test_attention_fwd_bwd( # comparison # print(f"Triton loss {loss} and torch loss {torch_loss}. 
Also checking grads...") - torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0) + torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0) for g1, g2 in zip(grads, torch_grads): - torch.testing.assert_allclose(g1, g2) + torch.testing.assert_close(g1, g2) @pytest.mark.parametrize("block", [16, 32, 64]) diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index f4e40d3a65a7..be59fc42ab57 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -25,7 +25,7 @@ def test_op(M, N, dtype, mode): tt_y = triton.ops.cross_entropy(x, idx) th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) if mode == 'forward': - torch.testing.assert_allclose(th_y, tt_y) + torch.testing.assert_close(th_y, tt_y) # backward pass elif mode == 'backward': dy = torch.randn_like(tt_y) @@ -37,4 +37,4 @@ def test_op(M, N, dtype, mode): th_y.backward(dy) th_dx = x.grad.clone() - torch.testing.assert_allclose(th_dx, tt_dx) + torch.testing.assert_close(th_dx, tt_dx) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 4bacf53b71ca..b6f74f2fc33d 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -51,7 +51,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): tri_dq, q.grad = q.grad.clone(), None # compare atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 - torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0) - torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0) - torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0) - torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0) + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0) + torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index f7e2ce2aa7e0..fa157d2c9771 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -52,7 +52,7 @@ def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel arg8_1 = torch.rand(64, device="cuda") arg9_1 = torch.rand(64, device="cuda") triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) - torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) + torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) def test_avg_pool_bw(): @@ -152,4 +152,4 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): out_ref[:, :, 1:7, 0::7] = 2 / 3 out_ref[:, :, 0::7, 1:7] = 2 / 3 out_ref[:, :, 0::7, 0::7] = 4 / 9 - torch.testing.assert_allclose(out, out_ref) + torch.testing.assert_close(out, out_ref) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index a7afa02f10b3..19b5e0f050a2 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -177,6 +177,6 @@ def init_input(m, n, dtype): if b_fp8: b = triton.reinterpret(b, getattr(tl, BDTYPE)) tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32) - torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0) + torch.testing.assert_close(th_c, tt_c) except 
triton.OutOfResources as e: pytest.skip(str(e)) From ec4a968d44cffe734b2e735da880f20aa749df86 Mon Sep 17 00:00:00 2001 From: "danny.jang" Date: Tue, 12 Sep 2023 04:31:30 +0900 Subject: [PATCH 025/122] [TESTS] Enhance benchmark flexibility (#2239) User can pass custom arguments to benchmarks. For example, user can pass `dtype` which will be used to create tensors in a benchmark. Co-authored-by: Keren Zhou --- python/triton/testing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index c4357bd243c0..69ee467d6b07 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -266,7 +266,7 @@ def __init__(self, fn, benchmarks): self.fn = fn self.benchmarks = benchmarks - def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool): + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags): import os import matplotlib.pyplot as plt @@ -287,7 +287,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b row_mean, row_min, row_max = [], [], [] for y in bench.line_vals: - ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args) + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) try: y_mean, y_min, y_max = ret except TypeError: @@ -328,14 +328,14 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b if save_path: df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False) - def run(self, show_plots=False, print_data=False, save_path=''): + def run(self, show_plots=False, print_data=False, save_path='', **kwargs): has_single_bench = isinstance(self.benchmarks, Benchmark) benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks if save_path: html = open(os.path.join(save_path, "results.html"), "w") html.write("\n") for bench in benchmarks: - self._run(bench, save_path, show_plots, print_data) + self._run(bench, save_path, show_plots, print_data, **kwargs) if save_path: html.write(f"\n") if save_path: From a9db6b94b97100815e27faf9a3e59720e8d90e28 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 11 Sep 2023 16:30:13 -0700 Subject: [PATCH 026/122] Remove wrong dependency between TritonGPU and NVGPU dialect (#2276) --- bin/RegisterTritonDialects.h | 3 ++- include/triton/Dialect/TritonGPU/IR/Dialect.h | 1 - include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td | 1 - lib/Conversion/TritonGPUToLLVM/Utility.h | 1 + lib/Dialect/NVGPU/IR/CMakeLists.txt | 1 + python/test/unit/operators/test_cross_entropy.py | 6 ++++-- unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp | 1 + 7 files changed, 9 insertions(+), 5 deletions(-) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 5cf1c3a25707..d3e0ee102a51 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -40,5 +40,6 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, - mlir::gpu::GPUDialect>(); + mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, + mlir::triton::nvgpu::NVGPUDialect>(); } diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 6ba2fe711816..24d971701242 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -7,7 
+7,6 @@ #include "mlir/IR/Dialect.h" // TritonGPU depends on Triton -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index 7533044b41a6..136b90ee65e5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -16,7 +16,6 @@ def TritonGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", - "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index f074b250a7bb..13aef00a6d0b 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -6,6 +6,7 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive // Operators diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/lib/Dialect/NVGPU/IR/CMakeLists.txt index 4e9e1ada172c..24a93ce58ea3 100644 --- a/lib/Dialect/NVGPU/IR/CMakeLists.txt +++ b/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -6,4 +6,5 @@ add_mlir_dialect_library(NVGPUIR NVGPUAttrDefsIncGen LINK_LIBS PUBLIC + MLIRLLVMDialect ) diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index be59fc42ab57..12739e56722c 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -36,5 +36,7 @@ def test_op(M, N, dtype, mode): x.grad = None th_y.backward(dy) th_dx = x.grad.clone() - - torch.testing.assert_close(th_dx, tt_dx) + if dtype == 'float16': + torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) + else: + torch.testing.assert_close(th_dx, tt_dx) diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp index 90e6ef8c3d1d..20603cd2e41c 100644 --- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp @@ -22,6 +22,7 @@ */ #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "DumpLayout.h" From 8da27c1c9596a7a9116c6234bc7b995e8b291ca3 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Mon, 11 Sep 2023 19:28:31 -0700 Subject: [PATCH 027/122] [Build] Fix very minor compilation problems (#2277) This PR fixes a few very minor compilation issues found in internal deployment at Meta. It looks like nit-picking, but it'd be really appreciated if it could be addressed in OSS Triton (to reduce differences from OSS), and we believe these changes are not bad in general. Neither performance nor functionality is affected by this PR. 1. Type cast in `python/triton/runtime/backends/cuda.c`. Implicit `void *` -> `cuuint{32,64}_t *` cast is not allowed by many compilers (with certain flags). It'd be nice to add an explicit cast (like `backends/hip.c`). 2. Inconsistent include path specification in `lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp`. Unlike other `DotOpToLLVM/*.cpp`, include paths used in `WGMMA.cpp` are not relative. 
This is problematic in some compilation settings since a compiler somehow needs to find headers in a parent directory. It'd be great to use a relative path, like other source files in Triton. cc: @yuguo68 --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 4 ++-- python/triton/runtime/backends/cuda.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index f03d09788c9d..9f943a615fa1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -21,8 +21,8 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#include "DotOpToLLVM.h" -#include "Utility.h" +#include "../DotOpToLLVM.h" +#include "../Utility.h" using namespace mlir; using namespace mlir::triton; diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 7dd60528f28f..278310473597 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -330,7 +330,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) { // Helper function to convert a Python list to a cuuint64_t array static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) { Py_ssize_t len = PyList_Size(listObj); - cuuint64_t *array = malloc(len * sizeof(cuuint64_t)); + cuuint64_t *array = (cuuint64_t *)malloc(len * sizeof(cuuint64_t)); for (Py_ssize_t i = 0; i < len; i++) { PyObject *item = PyList_GetItem(listObj, i); array[i] = (cuuint64_t)PyLong_AsUnsignedLongLong(item); @@ -341,7 +341,7 @@ static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) { // Helper function to convert a Python list to a cuuint32_t array static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) { Py_ssize_t len = PyList_Size(listObj); - cuuint32_t *array = malloc(len * sizeof(cuuint32_t)); + cuuint32_t *array = (cuuint32_t *)malloc(len * sizeof(cuuint32_t)); for (Py_ssize_t i = 0; i < len; i++) { PyObject *item = PyList_GetItem(listObj, i); array[i] = (cuuint32_t)PyLong_AsUnsignedLong(item); From a5e483652b5f8c72d96fdc4da28832259e12c1cc Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 11 Sep 2023 21:35:18 -0500 Subject: [PATCH 028/122] [NFC] Remove hard-coded warpSize=32 in scanOp lowering (#2272) - To make the development on AMD GPUs a little easier - Also changed `laneId` to `laneIdAxis` in some helper functions in scanOp lowering --- .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 5f8333e75009..42db21bd05a6 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -59,7 +59,7 @@ static void scanThreadContiguousElements(SmallVector &srcValues, // contiguous group of elements. 
static void warpScan(SmallVector &srcValues, ConversionPatternRewriter &rewriter, - ScanLoweringHelper &helper, Value laneId) { + ScanLoweringHelper &helper, Value laneIdAxis) { Location loc = helper.getLoc(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); @@ -76,7 +76,7 @@ static void warpScan(SmallVector &srcValues, Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride); Value tempAcc = acc; accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl); - Value mask = icmp_slt(laneId, i32_val(i)); + Value mask = icmp_slt(laneIdAxis, i32_val(i)); acc = select(mask, acc, tempAcc); } srcValues[srcIndex] = acc; @@ -124,7 +124,8 @@ static void storeWarpAccumulator(SmallVector &srcValues, static void AddPartialReduce(SmallVector &srcValues, ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value sharedMemoryPtr, - Value warpId, Value laneId, Value parallelLaneId) { + Value warpId, Value laneIdAxis, + Value parallelLaneId) { Location loc = helper.getLoc(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); unsigned numWarps = helper.getAxisNumWarps(); @@ -133,7 +134,7 @@ static void AddPartialReduce(SmallVector &srcValues, unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); - Value maskFirstLane = icmp_eq(laneId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); struct Accumulator { Value acc; @@ -228,7 +229,8 @@ struct ScanOpConversion private: std::tuple getDelinearizedIds(ConversionPatternRewriter &rewriter, - ScanLoweringHelper &helper) const; + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const; }; @@ -237,16 +239,12 @@ struct ScanOpConversion // compute a flat id for the parallel dimensions. 
std::tuple ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, - ScanLoweringHelper &helper) const { + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { auto loc = helper.getLoc(); unsigned axis = helper.getAxis(); auto srcEncoding = helper.getEncoding(); - Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(32); - Value warpId = udiv(threadId, warpSize); - Value laneId = urem(threadId, warpSize); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); auto order = triton::gpu::getOrder(srcEncoding); @@ -281,8 +279,15 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, if (!helper.isSupported()) return failure(); + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = i32_val(iWarpSize); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + auto [laneIdAxis, warpIdAxis, flatIdParallel] = - getDelinearizedIds(rewriter, helper); + getDelinearizedIds(rewriter, helper, laneId, warpId); auto input = adaptor.getOperands()[0]; auto type = op.getOperand(0).getType().cast(); SmallVector srcValues = From bf4f9375a7ab4c37e4901157e1a76f4e77a46465 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 11 Sep 2023 20:54:29 -0700 Subject: [PATCH 029/122] [FRONTEND] allow mixed precision FP8 matmul on pre-H100 hardware (#2281) --- lib/Analysis/Utility.cpp | 3 ++- lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 2 +- python/triton/language/semantic.py | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index f3acb8963c22..7cabed48a58f 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -384,7 +384,8 @@ bool supportMMA(Value value, int version) { // FP8 is not natively supported on all mma versions but it can always be // promoted to fp16 therefore we can always support it. 
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ() || + elemTy.isFloat8E4M3B11FNUZ(); return isFP8 || elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 29aece938b5c..4ef4f65ea1c7 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -845,7 +845,7 @@ struct ConvertTritonGPUToLLVM bool isNativeHopperFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() || - AElType.isFloat8E4M3FN(); + AElType.isFloat8E4M3FN() || AElType.isFloat8E4M3B11FNUZ(); if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper())) return; promoteType = builder.getF16Type(); diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8cccda9bed92..1df2376da406 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1266,6 +1266,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): # Checks for cuda arch if arch < 90: assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" else: assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" From ab9da3b2b86931b3dd7d5b15516722b3e3c6352e Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 12 Sep 2023 04:59:13 +0100 Subject: [PATCH 030/122] [FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (#2275) This fixes a few bugs related to scalar tensors: - `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type is forbidden')` - `scalar[None]` fails with `TypeError("'constexpr' object is not iterable")` - `scalar[None, None]` fails with `AttributeError("'dtype' object has no attribute 'shape'")` - `scalar.shape` returns `[1]` instead of 0-dim `[]` - Also related, `tl.zeros_like(scalar)` returns a 1d tensor instead of another scalar --- python/test/unit/language/test_core.py | 91 +++++++++++++++++--------- python/triton/language/core.py | 6 +- python/triton/language/semantic.py | 26 +++++--- 3 files changed, 79 insertions(+), 44 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f962facb4696..5bde5f1ddd94 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -568,7 +568,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # test broadcast # --------------- @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) -def test_broadcast(dtype): +def test_broadcast(dtype, device): @triton.jit def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): offset1 = tl.arange(0, M) @@ -585,41 +585,42 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con y = numpy_random(N, dtype_str=dtype, rs=rs) _, y_broadcasted_np = np.broadcast_arrays(x, y) - x_tri = to_triton(x, device='cuda', dst_type=dtype) - y_tri = to_triton(y, device='cuda', dst_type=dtype) - 
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() +# ---------- +# test slice +# ---------- + + +def test_slice(device): -# --------------- -# test broadcast -# --------------- -@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) -def test_broadcast(dtype, device): @triton.jit - def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): - offset1 = tl.arange(0, M) - offset2 = tl.arange(0, N) - x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) - y = tl.load(y_ptr + offset2) - _, y_broadcasted = tl.broadcast(x, y) - tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) - M = 32 - N = 64 - rs = RandomState(17) - x = numpy_random((M, N), dtype_str=dtype, rs=rs) - y = numpy_random(N, dtype_str=dtype, rs=rs) - _, y_broadcasted_np = np.broadcast_arrays(x, y) + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) - x_tri = to_triton(x, device=device, dst_type=dtype) - y_tri = to_triton(y, device=device, dst_type=dtype) - y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1,)](XBLOCK=32) - broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) - assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() # ------------------ # test invalid slice @@ -669,6 +670,14 @@ def expand_dims_kernel(dummy, N: tl.constexpr): t = tl.expand_dims(offset1, (3, 1, 2)) tl.static_assert(t.shape == [N, 1, 1, 1]) + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + N = 32 dummy_tensor = torch.empty((), device=device) expand_dims_kernel[(1,)](dummy_tensor, N) @@ -689,6 +698,13 @@ def dim_out_of_range2(dummy, N: tl.constexpr): t = tl.expand_dims(offset1, 1) t = tl.expand_dims(offset1, 2) + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + @triton.jit def duplicate_dim1(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -710,6 +726,9 @@ def duplicate_dim2(dummy, N: tl.constexpr): with pytest.raises(triton.CompilationError, match="invalid axis 2"): dim_out_of_range2[(1,)](dummy_tensor, N) + with pytest.raises(triton.CompilationError, match="invalid axis 1"): + dim_out_of_range3[(1,)](dummy_tensor, N) + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): duplicate_dim1[(1,)](dummy_tensor, N) @@ -2467,7 +2486,8 @@ def kernel(Z, X, Y, @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) -def test_full(dtype_str, device): 
+@pytest.mark.parametrize("shape", [(), (1,), (128,)]) +def test_full(dtype_str, shape, device): if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): # PyTorch only has unsigned 8, but not 16, 32, or 64 dtype = getattr(torch, dtype_str[1:]) # uintx -> intx @@ -2478,21 +2498,28 @@ def test_full(dtype_str, device): @triton.jit def kernel_static(out): a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) out_ptr = out + tl.arange(0, 128)[:] tl.store(out_ptr, a) @triton.jit def kernel_dynamic(out, val, dtype: tl.constexpr): - a = tl.full((128,), val, dtype) + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) out_ptr = out + tl.arange(0, 128)[:] tl.store(out_ptr, a) - kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"}) + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) out_static = torch.zeros((128), dtype=dtype, device=device) kernel_static_patched[(1,)](out_static) - out_dynamic = torch.zeros((128), dtype=dtype, device=device) - kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_dynamic == 2) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3caf8ee3fea5..3b9205f9d02a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -531,9 +531,7 @@ def __init__(self, handle, type: dtype): # IR handle self.handle = handle # Block shape - self.shape = (1, ) - if type.is_block(): - self.shape = type.shape + self.shape = type.shape if type.is_block() else () self.numel = 1 for s in self.shape: self.numel *= s @@ -743,7 +741,7 @@ def __not__(self, _builder=None): @builtin def __getitem__(self, slices, _builder=None): - if isinstance(slices, slice): + if isinstance(slices, (slice, constexpr)): slices = [slices] ret = self for dim, sl in enumerate(slices): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 1df2376da406..49597e4cb7af 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -501,25 +501,31 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" value = cast(value, dtype, builder) - ret_ty = tl.block_type(value.dtype, shape) - return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) else: # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") if value == 0: value = builder.get_null_value(dtype.to_ir(builder)) else: get_value_fn = getattr(builder, f"get_{dtype.name}") value = get_value_fn(value) - if dtype is None: - raise ValueError("dtype must be specified when value is not a tensor") - ret_ty = tl.block_type(dtype, shape) - return tl.tensor(builder.create_splat(value, shape), ret_ty) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) # ===----------------------------------------------------------------------===// # Shape Manipulation # ===----------------------------------------------------------------------===// +def splat(value: tl.tensor, shape: List[int], 
builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + def view(input: tl.tensor, dst_shape: List[int], @@ -544,8 +550,12 @@ def reshape(input: tl.tensor, def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - dst_shape = list(input.type.shape) + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + ret_ty = tl.block_type(input.type.scalar, dst_shape) return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) @@ -1506,7 +1516,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor: def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: - if len(x.shape) != len(values): + if max(1, len(x.shape)) != len(values): raise ValueError("Shape of input to multiple_of does not match the length of values") x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) return x From fc5d7e6e7c840b66c2f574148be9a5214875bd72 Mon Sep 17 00:00:00 2001 From: jsh-20 <123707385+jsh-20@users.noreply.github.com> Date: Tue, 12 Sep 2023 17:14:47 +0800 Subject: [PATCH 031/122] =?UTF-8?q?[FRONTEND]=20Improve=20grid=20calculati?= =?UTF-8?q?on=20for=20persistent=20kernels=20to=20hoist=20pe=E2=80=A6=20(#?= =?UTF-8?q?2283)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rf on problems that need few blocks. constrain the number of launched blocks to what it exactely needs for persistent warp specialized kernel. It's useful when problems need very few blocks. e.g. MxNxK=800x800x60000, f16_f16_f32, block size=128x128x64, non-split-k. Experiments show it can achieve ~16% speedup. 
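The grid change amounts to the following sketch (problem and block sizes are taken from the example above; `num_SMs` is hard-coded here instead of being queried from the device):

```python
import triton

M, N = 800, 800
BLOCK_M, BLOCK_N = 128, 128
num_SMs = 108  # e.g. an A100; normally taken from the device properties

# Before: always launch one persistent block per SM.
grid_before = (num_SMs,)

# After: never launch more blocks than there are output tiles.
tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
grid_after = (min(num_SMs, tiles),)

assert grid_before == (108,)
assert grid_after == (49,)  # 7 x 7 tiles, far fewer than 108 SMs
```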
--- .../unit/hopper/test_persistent_warp_specialized_gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index d66c3b7952fc..bd1e70ec58a9 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -141,7 +141,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (num_SMs,) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) if USE_TMA: static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) @@ -432,7 +432,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (num_SMs,) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) if USE_TMA: static_persistent_tma_warp_specialized_matmul_kernel[grid]( @@ -899,7 +899,7 @@ def process_epilogue(d, bias, w, epilogue): num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count def grid(META): - return (num_SMs,) + return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) full_static_persistent_matmul_kernel[grid]( a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, M=M, N=N, K=K, From a47f1f5c28c51be0eff3d35b5076c6f4d4bc2a05 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Tue, 12 Sep 2023 08:46:19 -0700 Subject: [PATCH 032/122] [BACKEND] Unify slow/fast reduce codegen (#2220) --- include/triton/Analysis/Utility.h | 12 +- lib/Analysis/Utility.cpp | 92 ++-- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 425 ++++++------------ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 18 + lib/Conversion/TritonGPUToLLVM/Utility.h | 3 + python/test/unit/language/test_core.py | 94 ++-- test/Analysis/test-allocation.mlir | 4 +- 7 files changed, 272 insertions(+), 376 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 4b6dad26cde3..0af8eceaade1 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -36,7 +36,9 @@ class ReduceOpHelper { triton::ReduceOp getOperation() { return op; } - bool isFastReduction(); + bool isReductionOnLayoutFastAxis(); + + unsigned getThreadOffsetOnReductionAxis(); bool isWarpSynchronous(); @@ -50,14 +52,16 @@ class ReduceOpHelper { unsigned getThreadsReductionAxis(); - SmallVector getScratchConfigBasic(); - - SmallVector> getScratchConfigsFast(); + SmallVector getScratchConfig(); unsigned getScratchSizeInBytes(); bool isSupportedLayout(); + bool isReduceWithinCTA(); + + unsigned getAxis() { return axis; } + private: triton::ReduceOp op; ArrayRef srcShape; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 7cabed48a58f..eff4eb527a0f 100644 --- 
a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -33,14 +33,39 @@ SmallVector getParentOrder(Attribute layout) { } // namespace -bool ReduceOpHelper::isFastReduction() { - // Disable fast reduction only for debugging purpose - if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION")) - return false; +bool ReduceOpHelper::isReductionOnLayoutFastAxis() { return getParentAxis(getSrcLayout(), axis) == getParentOrder(getSrcLayout())[0]; } +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. +unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto srcLayout = getSrcLayout(); + + // If the reduction axis is the fast axis of the parent layout + if (isReductionOnLayoutFastAxis()) { + return 1; + } + + unsigned threadOffset = 1; + if (auto sliceLayout = + srcLayout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(parentLayout); + threadOffset = threadsPerWarp[sliceLayout.getDim()]; + } else { + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + if (threadsPerWarp.size() == 1) { + threadOffset = 1; + } else { + assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts"); + threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0]; + } + } + return threadOffset; +} + // Cases where distributed shared memory is not required in ConvertLayout: // (1) numCTAs == 1 // (2) numCTAs > 1 but srcCTALayout == dstCTALayout @@ -124,53 +149,26 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() { triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; } -SmallVector ReduceOpHelper::getScratchConfigBasic() { - auto smemShape = convertType(getSrcShape()); - smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis()); - return smemShape; -} - bool ReduceOpHelper::isWarpSynchronous() { auto argsLayout = getSrcLayout(); - return isFastReduction() && - (triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1); + return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1; } -SmallVector> ReduceOpHelper::getScratchConfigsFast() { - SmallVector> smemShapes(3); - - auto argLayout = getSrcLayout(); - auto argLayoutMma = argLayout.dyn_cast(); - +SmallVector ReduceOpHelper::getScratchConfig() { + SmallVector smemShape; // that case doesn't need inter-warp communication if (isWarpSynchronous()) - return {{0, 0}, {0, 0}}; + return {0, 0}; - /// shared memory block0 - smemShapes[0] = convertType(getSrcShape()); - smemShapes[0][axis] = getInterWarpSize(); - - /// FIXME(Qingyi): This size is actually larger than required. 
- /// shared memory block1: - auto mod = op->getParentOfType(); - unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - unsigned threadsPerWarp = - triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - smemShapes[1].push_back(numWarps * threadsPerWarp); + smemShape = convertType(getSrcShape()); + smemShape[axis] = getInterWarpSizeWithUniqueData(); - return smemShapes; + return smemShape; } unsigned ReduceOpHelper::getScratchSizeInBytes() { - unsigned elems = 0; - if (isFastReduction()) { - auto smemShapes = getScratchConfigsFast(); - for (const auto &smemShape : smemShapes) - elems = std::max(elems, product(smemShape)); - } else { - auto smemShape = getScratchConfigBasic(); - elems = product(smemShape); - } + auto smemShape = getScratchConfig(); + auto elems = product(smemShape); unsigned bytesPerElem = 0; for (const auto &ty : srcElementTypes) { @@ -179,7 +177,21 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { return bytesPerElem * elems; } +bool ReduceOpHelper::isReduceWithinCTA() { + auto axis = getAxis(); + auto srcLayout = getSrcLayout(); + auto CTASplitNum = mlir::triton::gpu::getCTASplitNum(srcLayout); + assert(axis < CTASplitNum.size()); + return CTASplitNum[axis] == 1; +} + bool ReduceOpHelper::isSupportedLayout() { + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + if (!isReduceWithinCTA()) { + return false; + } + auto srcLayout = getSrcLayout(); if (srcLayout.isa()) { return true; diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index a00e105cf5ee..914696b8d246 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -9,9 +9,9 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; +using ::mlir::LLVM::loadShared; using ::mlir::LLVM::shflSync; using ::mlir::LLVM::storeShared; -using ::mlir::triton::gpu::getCTASplitNum; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getTotalElemsPerThread; @@ -29,29 +29,59 @@ struct ReduceOpConversion LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // When cross-CTA reduction is implemented in the future, this assertion can - // be removed - assert(isReduceWithinCTA(op) && - "Layout optimization passes such as PlanCTAPass and " - "RemoveLayoutConversionPass should avoid cross-CTA reduction"); - - if (ReduceOpHelper(op).isFastReduction()) - return matchAndRewriteFast(op, adaptor, rewriter); - return matchAndRewriteBasic(op, adaptor, rewriter); + ReduceOpHelper helper(op); + assert(helper.isSupportedLayout() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. 
+ auto smemShape = helper.getScratchConfig(); + + SmallVector smemBases = + getSmemBases(helper, op, smemShape, rewriter); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); } private: int computeCapability; - bool isReduceWithinCTA(triton::ReduceOp op) const { - auto axis = op.getAxis(); - ReduceOpHelper helper(op); - auto srcLayout = helper.getSrcLayout(); - auto CTASplitNum = getCTASplitNum(srcLayout); - assert(axis < CTASplitNum.size()); - return CTASplitNum[axis] == 1; - } - void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, SmallVector &acc, ValueRange cur, bool isFirst) const { if (isFirst) { @@ -103,203 +133,36 @@ struct ReduceOpConversion return srcValues; } - // Calculates the write index in the shared memory where we would be writing - // the within-thread accumulations before we start doing across-threads - // accumulations. `index` is the index of the within-thread accumulations in - // the full tensor, whereas `writeIdx` is the mapped-to index in the shared - // memory - void getWriteIndexBasic(ConversionPatternRewriter &rewriter, Location loc, - Attribute layout, SmallVector &index, - SmallVector &writeIdx, - std::map &ints, unsigned originalAxis, - unsigned axis) const { - if (auto sliceLayout = layout.dyn_cast()) { - // Recover the axis in the parent layout - auto parentAxis = axis < sliceLayout.getDim() ? axis : axis + 1; - auto parentLayout = sliceLayout.getParent(); - getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints, - originalAxis, parentAxis); - return; - } - - writeIdx = index; - auto sizePerThread = triton::gpu::getSizePerThread(layout); - Value axisSizePerThread = ints[sizePerThread[axis]]; - Value _8 = ints[8]; - Value _16 = ints[16]; - if (layout.isa()) { - // A single thread owns axisSizePerThread contiguous values - // on the reduction axis. After within thread reduction, - // we would have a single accumulation every `axisSizePerThread` - // contiguous values in the original tensor, so we would need - // to map every `axisSizePerThread` to 1 value in smem as: - // writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread - writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread); - } else if (auto mmaLayout = layout.dyn_cast()) { - if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) { - llvm::report_fatal_error("Unsupported layout"); - } - if (originalAxis == 0) { - // Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8 - // rows in smem would correspond to a warp. 
The mapping - // is: (warp_index) x 8 + (row index within warp) - writeIdx[originalAxis] = add(mul(udiv(index[originalAxis], _16), _8), - urem(index[originalAxis], _8)); - } else { - // Same as BlockedEncodingAttr case - writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread); - } - } else { - llvm::report_fatal_error("Unsupported layout"); - } - } - - // Use shared memory for reduction within warps and across warps - LogicalResult - matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - ReduceOpHelper helper(op); - Location loc = op.getLoc(); - unsigned axis = op.getAxis(); - - auto srcTys = op.getInputTypes(); - auto srcLayout = helper.getSrcLayout(); - if (!helper.isSupportedLayout()) { - assert(false && "Unexpected srcLayout in ReduceOpConversion"); - } - // The order of the axes for the the threads within the warp - auto srcOrd = triton::gpu::getOrder(srcLayout); - auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); - auto srcShape = helper.getSrcShape(); - - SmallVector elemPtrTys(srcTys.size()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto ty = srcTys[i].getElementType(); - auto llvmElemTy = getTypeConverter()->convertType(ty); - elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); - } - auto llvmIndexTy = getTypeConverter()->getIndexType(); - - auto smemShape = helper.getScratchConfigBasic(); + SmallVector getSmemBases(ReduceOpHelper &helper, triton::ReduceOp op, + SmallVector smemShape, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); unsigned elems = product(smemShape); - - SmallVector smemBases(op.getNumOperands()); - smemBases[0] = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + indexToBase[indices[0]] = + bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), + getElementPtrType(op, indices[0])); for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = - bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), - elemPtrTys[i]); + indexToBase[indices[i]] = + bitcast(gep(getElementPtrType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)), + getElementPtrType(op, indices[i])); } - - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - std::map, SmallVector> accs; - std::map, SmallVector> indices; - reduceWithinThreads(helper, srcValues, accs, indices, rewriter); - - // cached int32 constants - std::map ints; - ints[0] = i32_val(0); - for (int N = smemShape[axis] / 2; N > 0; N >>= 1) - ints[N] = i32_val(N); - ints[sizePerThread[axis]] = i32_val(sizePerThread[axis]); - ints[8] = i32_val(8); - ints[16] = i32_val(16); - - // reduce across threads - for (auto it : accs) { - const SmallVector &key = it.first; - auto &acc = it.second; - // get the writeIdx at which to write in smem - SmallVector writeIdx; - getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints, - axis, axis); - - // calculate the offset in smem for that writeIdx - Value writeOffset = linearize(rewriter, loc, writeIdx, 
smemShape, srcOrd); - SmallVector writePtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - // Store the within-thread accumulated value into shared memory - writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); - store(acc[i], writePtrs[i]); - } - - SmallVector readIdx(writeIdx.size(), ints[0]); - // Perform parallel reduction with sequential addressing - // E.g. We reduce `smemShape[axis]` elements into `smemShape[axis]/2` - // elements using `smemShape[axis]/2` threads where each thread - // would accumalte values that are `smemShape[axis]/2` apart - // to avoid bank conflicts. Then we repeat with `smemShape[axis]/4` - // threads, .. etc. - for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { - // The readIdx will be N elements away on the reduction axis - readIdx[axis] = ints[N]; - // If the writeIdx is greater or equal to N, do nothing - Value readMask = icmp_slt(writeIdx[axis], ints[N]); - // Calculate the readOffset, if readMask is False, readOffset=0 - // meaning we reduce the value at writeIdx with itself - Value readOffset = select( - readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), - ints[0]); - SmallVector readPtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - // The readPtr is readOffset away from writePtr - readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); - } - - sync(rewriter, loc, op); - - // Combine accumulator value from another thread - SmallVector cur(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - cur[i] = load(readPtrs[i]); - } - accumulate(rewriter, op.getCombineOp(), acc, cur, false); - - sync(rewriter, loc, op); - - // Publish our new accumulator value to shared memory - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - store(acc[i], writePtrs[i]); - } - } - } - - sync(rewriter, loc, op); - - // set output values - SmallVector results(op.getNumOperands()); + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = - op.getResult()[i].getType().dyn_cast()) { - // nd-tensor where n >= 1 - - auto resultLayout = resultTy.getEncoding(); - - unsigned resultElems = getTotalElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (unsigned j = 0; j < resultElems; ++j) { - SmallVector readIdx = resultIndices[j]; - readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShape, srcOrd); - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); - resultVals[j] = load(readPtr); - } - results[i] = getTypeConverter()->packLLElements(loc, resultVals, - rewriter, resultTy); - } else { - // 0d-tensor -> scalar - results[i] = load(smemBases[i]); - } + smemBases[i] = indexToBase[i]; } - - auto parentBlock = op.getOperation()->getBlock(); - rewriter.replaceOp(op, results); - return success(); + return smemBases; } void sync(ConversionPatternRewriter &rewriter, Location loc, @@ -381,7 +244,7 @@ struct ReduceOpConversion // region and the accumulator values as source. 
void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { + unsigned numLaneToReduce, unsigned interleave) const { if (auto kind = matchReduxKind(op)) { // Based on benchmarking on A100 redux op gives a speed up only when doing // a single reduction (not partioned) and when the mask is static. @@ -420,7 +283,7 @@ struct ReduceOpConversion for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { SmallVector shfl(acc.size()); for (unsigned i = 0; i < acc.size(); ++i) { - shfl[i] = shflSync(loc, rewriter, acc[i], N); + shfl[i] = shflSync(loc, rewriter, acc[i], N * interleave); } accumulate(rewriter, op.getCombineOp(), acc, shfl, false); } @@ -433,10 +296,13 @@ struct ReduceOpConversion ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); for (auto it : accs) { const SmallVector &key = it.first; SmallVector &acc = accs[key]; - warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps); + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); } } @@ -476,6 +342,32 @@ struct ReduceOpConversion return LLVM::LLVMPointerType::get(llvmElemTy, 3); } + SmallVector + getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc, + ConversionPatternRewriter &rewriter) const { + auto srcLayout = helper.getSrcLayout(); + auto srcShape = helper.getSrcShape(); + auto order = getOrder(srcLayout); + SmallVector multiDimWarpId; + + // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same + // address as warpId = 0 since the warpsPerCTA is [1, 2], need to figure out + // a way to properly delinearize warpId in the slice case + if (auto sliceLayout = srcLayout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout); + auto parentOrder = triton::gpu::getOrder(parentLayout); + multiDimWarpId = + delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder); + multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim()); + } else { + auto warpsPerCTA = + triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape); + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + } + return multiDimWarpId; + } + void storeWarpReduceToSharedMemory( ReduceOpHelper &helper, std::map, SmallVector> &accs, @@ -491,32 +383,31 @@ struct ReduceOpConversion auto srcLayout = helper.getSrcLayout(); auto srcShape = helper.getSrcShape(); unsigned axis = op.getAxis(); - auto smemShapes = helper.getScratchConfigsFast(); + auto smemShape = helper.getScratchConfig(); auto threadsPerWarp = triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); - auto warpsPerCTA = - triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape); auto order = getOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); - Value laneIdAxis = multiDimLaneId[axis]; - Value warpIdAxis = multiDimWarpId[axis]; - Value zero = i32_val(0); Value laneZero = icmp_eq(laneIdAxis, zero); + SmallVector multiDimWarpId = + getMultiDimWarpId(helper, warpId, loc, rewriter); + Value warpIdAxis = multiDimWarpId[axis]; + + if (!helper.isReductionOnLayoutFastAxis()) { + std::reverse(order.begin(), 
order.end()); + } for (auto it : accs) { const SmallVector &key = it.first; SmallVector &acc = it.second; SmallVector writeIdx = indices[key]; writeIdx[axis] = warpIdAxis; - Value writeOffset = - linearize(rewriter, loc, writeIdx, smemShapes[0], order); + Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemPtrTy = getElementPtrType(op, i); Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset); @@ -532,8 +423,8 @@ struct ReduceOpConversion ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); auto srcLayout = helper.getSrcLayout(); - auto smemShapes = helper.getScratchConfigsFast(); - unsigned elems = product(smemShapes[0]); + auto smemShape = helper.getScratchConfig(); + unsigned elems = product(smemShape); unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); Location loc = op.getLoc(); @@ -547,18 +438,16 @@ struct ReduceOpConversion product(triton::gpu::getWarpsPerCTA(srcLayout)) * triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { - // FIXME(Qingyi): need predicate icmp_slt(threadId, - // i32_val(sizeInerWarps)) SmallVector acc(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemPtrTy = getElementPtrType(op, i); Value readPtr = gep(elemPtrTy, smemBases[i], readOffset); - acc[i] = load(readPtr); + acc[i] = loadShared(rewriter, loc, readPtr, threadIsNeeded); } - warpReduce(rewriter, loc, acc, op, sizeInterWarps); - + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; SmallVector writePtrs(op.getNumOperands()); @@ -566,7 +455,7 @@ struct ReduceOpConversion auto elemPtrTy = getElementPtrType(op, i); writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset); } - Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); Value laneIdModSizeInterWarpsIsZero = icmp_eq(laneIdModSizeInterWarps, zero); @@ -585,12 +474,17 @@ struct ReduceOpConversion // Load the final reduction from shared memory and replace the reduce result // with it. 
void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, SmallVector &smemBases, ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); Location loc = op.getLoc(); - auto smemShapes = helper.getScratchConfigsFast(); - auto order = getOrder(helper.getSrcLayout()); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto order = getOrder(srcLayout); + if (!helper.isReductionOnLayoutFastAxis()) { + std::reverse(order.begin(), order.end()); + } SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { if (auto resultTy = @@ -606,7 +500,7 @@ struct ReduceOpConversion SmallVector readIdx = resultIndices[j]; readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[0], order); + linearize(rewriter, loc, readIdx, smemShape, order); Value readPtr = gep(getElementPtrType(op, i), smemBases[i], readOffset); resultVals[j] = load(readPtr); @@ -621,67 +515,6 @@ struct ReduceOpConversion } rewriter.replaceOp(op, results); } - - // Use warp shuffle for reduction within warps and shared memory for data - // exchange across warps - LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - ReduceOpHelper helper(op); - assert(helper.isSupportedLayout() && - "Unexpected srcLayout in ReduceOpConversion"); - Location loc = op->getLoc(); - - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - std::map, SmallVector> accs; - std::map, SmallVector> indices; - // First reduce all the values along axis within each thread. - reduceWithinThreads(helper, srcValues, accs, indices, rewriter); - - // Then reduce across threads within a warp. - reduceWithinWarps(helper, accs, rewriter); - - if (helper.isWarpSynchronous()) { - // If all the values to be reduced are within the same warp there is - // nothing left to do. - packResults(helper, accs, rewriter); - return success(); - } - - // Compute a shared memory base per operand. - auto smemShapes = helper.getScratchConfigsFast(); - unsigned elems = product(smemShapes[0]); - unsigned maxElems = std::max(elems, product(smemShapes[1])); - SmallVector smemBases(op.getNumOperands()); - smemBases[0] = - bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), - getElementPtrType(op, 0)); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1], - i32_val(maxElems)), - getElementPtrType(op, i)); - } - storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); - - sync(rewriter, loc, op); - - // The second round of shuffle reduction - // now the problem size: sizeInterWarps, s1, s2, .. , sn - // where sizeInterWarps is 2^m - // - // Each thread needs to process: - // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads - accumulatePartialReductions(helper, smemBases, rewriter); - - // We could avoid this barrier in some of the layouts, however this is not - // the general case. - // TODO: optimize the barrier in case the layouts are accepted. 
- sync(rewriter, loc, op); - - // set output values - loadReductionAndPackResult(helper, smemBases, rewriter); - - return success(); - } }; void populateReduceOpToLLVMPatterns( diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 1c9fc9a58f50..f89cfe7b8c7d 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -250,6 +250,24 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, return builder.launch(rewriter, loc, void_ty(ctx)); } +Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value pred) { + MLIRContext *ctx = rewriter.getContext(); + auto ptrTy = ptr.getType().cast(); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); + auto elemTy = ptrTy.getElementType(); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + + const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r"); + + PTXBuilder builder; + auto *dOpr = builder.newOperand(c); + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + auto &ld = builder.create<>("ld")->shared().b(bitwidth); + ld(dOpr, ptrOpr).predicate(pred, "b"); + return builder.launch(rewriter, loc, elemTy); +} + static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i, const std::string &shuffleType, const std::string &clamp) { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 13aef00a6d0b..5a7adeab492f 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -321,6 +321,9 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred); +Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value pred); + Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5bde5f1ddd94..70f8d3d80900 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1838,46 +1838,73 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32]]) @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("axis", [0, 1]) -def test_reduce_layouts(M, N, src_layout, axis, device): +@pytest.mark.parametrize("reduce2d", [False, True]) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, device): if is_hip(): pytest.skip("test_reduce_layouts is not supported in HIP") - - rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maxf", "float16": "arith.maxf"}, + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = { + "max": np.max, + "sum": np.sum + }[reduce_op] rdims_1d = f"{N}" if axis == 0 else f"{M}" 
+ rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" store_range = "%7" if axis == 0 else "%1" blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + epilogue = f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 {{cache = 1 : i32, evict = 1 : i32}} : {ty} + tt.return + }} + }} + """ if reduce2d else f""" + %14 = tt.splat %arg2 : (!tt.ptr<{ty}>) -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}x{ty}, #blocked> + tt.return + }} + }} + """ + ir = f""" #blocked = {blocked} #src = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> - %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> - %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : (!tt.ptr<{ty}>) -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> - %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> - %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<{rdims_2d}x!tt.ptr, #blocked> - %12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr, #blocked>, tensor<{rdims_2d}xi32, #blocked> - %13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked> - %14 = {GPU_DIALECT}.convert_layout %13 : (tensor<{M}x{N}xi32, 
#blocked>) -> tensor<{M}x{N}xi32, #src> - %15 = "tt.reduce"(%14) ({{ - ^bb0(%arg3: i32, %arg4: i32): - %17 = arith.addi %arg3, %arg4 : i32 - tt.reduce.return %17 : i32 - }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> - %18 = {GPU_DIALECT}.convert_layout %15 : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> - %19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked> - tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked> - tt.return - }} - }} - """ + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x{ty}, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : (tensor<{M}x{N}x{ty}, #blocked>) -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: @@ -1886,21 +1913,20 @@ def test_reduce_layouts(M, N, src_layout, axis, device): kernel = triton.compile(f.name) rs = RandomState(17) - x = rs.randint(0, 20, (M, N)).astype('int32') - - if axis == 0: - z = np.zeros((1, N)).astype('int32') - else: - z = np.zeros((M, 1)).astype('int32') + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) x_tri = torch.tensor(x, device=device) z_tri = torch.tensor(z, device=device) - pgm = kernel[(1, 1, 4)](x_tri, x_tri.stride(0), z_tri) + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) - z_ref = np.sum(x, axis=axis, keepdims=True) - - np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) layouts = [ diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 93d80448c998..919dca69201d 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -265,14 +265,14 @@ tt.func @alloc_m_barrier_scalar() { // CHECK-LABEL: scratch tt.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - // CHECK: scratch offset = 0, size = 512 + // CHECK: scratch offset = 0, size = 128 %b = "tt.reduce" (%cst0) ({ ^bb0(%arg0: f16, %arg1: f16): %add = arith.addf %arg0, %arg1 : f16 tt.reduce.return %add : f16 }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> tt.return - // CHECK-NEXT: size = 512 + // CHECK-NEXT: size = 128 } // CHECK-LABEL: trans From 37f12497b0cbcc6754f2d2f7557055103f832d37 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Tue, 12 Sep 2023 08:57:01 -0700 Subject: [PATCH 033/122] [FRONTEND] Add PyTorch fp8 dtypes to Triton (#2279) Add PyTorch fp8 dtypes 
(https://github.com/pytorch/pytorch/blob/8025b193a966a6d8e3afc9c03a54e577bc04eb3d/torchgen/api/types/types.py#L50-L51) to Triton. --- python/triton/runtime/jit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d306c160e8ee..f664b33c71ab 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -250,6 +250,8 @@ def _type_of(key): "float8e5": "fp8e5", "float8e4b15": "fp8e4b15", "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", "float16": "fp16", "bfloat16": "bf16", "float32": "fp32", From 994f7e44601602545340898b7c8aabb61309c9a5 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Tue, 12 Sep 2023 11:02:20 -0700 Subject: [PATCH 034/122] [BACKEND] Remove dependency between NVGPU and TritonNvidiaGPU (#2282) --- bin/RegisterTritonDialects.h | 1 + include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h | 1 - .../Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td | 1 - lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 1 + lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h | 1 + lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 6 ++++++ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 1 + lib/Conversion/TritonGPUToLLVM/Utility.h | 1 - python/test/unit/operators/test_cross_entropy.py | 2 +- 9 files changed, 11 insertions(+), 4 deletions(-) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index d3e0ee102a51..29ba31eaf1f3 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,4 +1,5 @@ #pragma once +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index 680af81ac41c..d07f0743615d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -30,7 +30,6 @@ #include "mlir/IR/Dialect.h" // TritonNvidiaGPU depends on Triton -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Traits.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td index 08ff21f523f0..f2ab288c1799 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -38,7 +38,6 @@ def TritonNvidiaGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "triton::gpu::TritonGPUDialect", - "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 69545a00d83e..fcce9884bc6e 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -6,6 +6,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 
ba0423203a0b..2b12e727025f 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -11,6 +11,7 @@ #include "Utility.h" #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Target/PTX/TmaMetadata.h" #include diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 4ef4f65ea1c7..fb2f46f9936a 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -18,6 +18,7 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -387,6 +388,11 @@ struct ConvertTritonGPUToLLVM using ConvertTritonGPUToLLVMBase< ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + ConvertTritonGPUToLLVM(int32_t computeCapability, Target target, mlir::triton::gpu::TMAMetadataTy *tmaMetadata) : ConvertTritonGPUToLLVMBase({computeCapability, target}), diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index f89cfe7b8c7d..06d338685909 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -1,5 +1,6 @@ #include "Utility.h" #include "TypeConverter.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" namespace mlir { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 5a7adeab492f..9dd072d0c942 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -6,7 +6,6 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive // Operators diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index 12739e56722c..f6ae42ac3e9a 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -36,7 +36,7 @@ def test_op(M, N, dtype, mode): x.grad = None th_y.backward(dy) th_dx = x.grad.clone() - if dtype == 'float16': + if dtype == torch.float16: torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) else: torch.testing.assert_close(th_dx, tt_dx) From e95e1f12eb5bb061ac28f47c5841824ef9d124b4 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:02:25 -0700 Subject: [PATCH 035/122] [BACKEND] Convert layout illegal mem access fix (#2287) --- include/triton/Analysis/Allocation.h | 1 + lib/Analysis/Allocation.cpp | 52 ++++++++------ .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 31 +++++++-- .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 18 ++--- python/test/unit/language/test_core.py | 68 ++++++++++--------- 5 files changed, 106 insertions(+), 64 deletions(-) diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 
6370eba55b32..521ffec3a739 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -21,6 +21,7 @@ class AllocationAnalysis; SmallVector getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, unsigned &outVec); +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); } // namespace triton diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 567501d3dbce..ec3757208016 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -18,6 +18,7 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getUniqueContigPerThread; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -50,9 +51,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) { return {inOrd, outOrd}; } -SmallVector -getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, - unsigned &outVec) { +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { auto srcTy = op.getSrc().getType().cast(); auto dstTy = op.getResult().getType().cast(); Attribute srcLayout = srcTy.getEncoding(); @@ -76,15 +75,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, } } - assert(srcLayout && dstLayout && - "Unexpected layout in getScratchConfigForCvtLayout()"); - auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); - unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]]; - unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]]; - // TODO: Fix the legacy issue that ourOrd[0] == 0 always means - // that we cannot do vectorization. - inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; - outVec = outOrd[0] == 0 ? 1 : dstContigPerThread; + assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); auto srcShapePerCTA = getShapePerCTA(srcTy); auto dstShapePerCTA = getShapePerCTA(dstTy); @@ -92,21 +83,44 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); unsigned rank = dstTy.getRank(); - SmallVector paddedRepShape(rank); - unsigned pad = std::max(inVec, outVec); + SmallVector repShape(rank); for (unsigned d = 0; d < rank; ++d) { - paddedRepShape[d] = + repShape[d] = std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); } - if (rank == 1) - return paddedRepShape; + return repShape; +} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto repShape = getRepShapeForCvtLayout(op); + + auto srcTy = op.getSrc().getType().cast(); + auto dstTy = op.getResult().getType().cast(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + unsigned srcContigPerThread = + getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + unsigned dstContigPerThread = + getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means + // that we cannot do vectorization. + inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; + outVec = outOrd[0] == 0 ? 
1 : dstContigPerThread; + + if (repShape.size() <= 1) + return repShape; unsigned paddedDim = 1; if (auto dstBlockedLayout = dstLayout.dyn_cast()) { paddedDim = dstBlockedLayout.getOrder()[0]; } - paddedRepShape[paddedDim] += pad; - return paddedRepShape; + unsigned pad = std::max(inVec, outVec); + repShape[paddedDim] += pad; + return repShape; } SmallVector diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1aadf5093884..70e675c7bc5f 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -237,12 +237,30 @@ struct ConvertLayoutOpConversion llvm_unreachable("unexpected layout in getMultiDimOffset"); } + SmallVector + getWrappedMultiDimOffset(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, + ArrayRef shape, + SmallVector shapePerCTATile, + SmallVector shapePerCTA) const { + unsigned rank = shape.size(); + SmallVector multiDimOffsetWrapped(rank); + for (unsigned d = 0; d < rank; ++d) { + if (shapePerCTATile[d] > shapePerCTA[d]) + multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + else + multiDimOffsetWrapped[d] = multiDimOffset[d]; + } + return multiDimOffsetWrapped; + } + // shared memory rd/st for blocked or mma layout with data padding void processReplica(Location loc, ConversionPatternRewriter &rewriter, bool stNotRd, RankedTensorType type, ArrayRef numCTAsEachRep, ArrayRef multiDimRepId, unsigned vec, ArrayRef paddedRepShape, + ArrayRef origRepShape, ArrayRef outOrd, SmallVector &vals, Value smemBase) const { auto accumNumCTAsEachRep = product(numCTAsEachRep); @@ -286,8 +304,11 @@ struct ConvertLayoutOpConversion SmallVector multiDimOffset = getMultiDimOffset(layout, loc, rewriter, elemId, type, multiDimCTAInRepId, shapePerCTATile); - Value offset = - linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd); + SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( + rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); auto elemPtrTy = ptr_ty(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); @@ -575,6 +596,7 @@ struct ConvertLayoutOpConversion rewriter, srcTy); unsigned inVec = 0; unsigned outVec = 0; + auto origRepShape = getRepShapeForCvtLayout(op); auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); if (getElementTypeOrSelf(op.getType()) .isa()) { @@ -618,7 +640,7 @@ struct ConvertLayoutOpConversion else processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape, - outOrd, vals, smemBase); + origRepShape, outOrd, vals, smemBase); } else { assert(0 && "ConvertLayout with input layout not implemented"); return failure(); @@ -651,7 +673,8 @@ struct ConvertLayoutOpConversion else processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, - paddedRepShape, outOrd, outVals, smemBase); + paddedRepShape, origRepShape, outOrd, outVals, + smemBase); } else { assert(0 && "ConvertLayout with output layout not implemented"); return failure(); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 2b12e727025f..78c4f92caa53 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ 
b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -339,6 +339,8 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Order auto inOrder = triton::gpu::getOrder(srcEncoding); auto outOrder = triton::gpu::getOrder(resSharedLayout); + assert(outVec * (maxPhase - 1) <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); // Tensor indices held by the current thread, as LLVM values auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false); // Swizzling with leading offsets (e.g. Hopper GMMA) @@ -452,10 +454,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto dstElemTy = dstTy.getElementType(); auto inOrd = triton::gpu::getOrder(srcSharedLayout); auto outOrd = triton::gpu::getOrder(dstDistributedLayout); - unsigned outVec = - inOrd == outOrd - ? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]] - : 1; + unsigned outVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + dstDistributedLayout, dstShape)[outOrd[0]] + : 1; unsigned inVec = srcSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); @@ -501,10 +503,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto dstElemTy = dstTy.getElementType(); auto inOrd = triton::gpu::getOrder(srcDistributedLayout); auto outOrd = dstSharedLayout.getOrder(); - unsigned inVec = - inOrd == outOrd - ? triton::gpu::getContigPerThread(srcDistributedLayout)[inOrd[0]] - : 1; + unsigned inVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + srcDistributedLayout, srcShape)[inOrd[0]] + : 1; unsigned outVec = dstSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 70f8d3d80900..df798d5ac6c2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3607,6 +3607,7 @@ def kernel(Out): # MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), # MmaLayout(1, [4, 1], [1, 1], [0, 1]), # MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -3624,15 +3625,16 @@ def kernel(Out): ] -@pytest.mark.parametrize("shape", [(128, 128)]) +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): if is_hip(): pytest.skip("test_convert2d is not supported in HIP") - + if (M == 1 or N == 1) and interm_layout: + pytest.skip("Out of bound access when maxPhase > 1") if str(src_layout) == str(dst_layout): pytest.skip() if 'mma' in str(src_layout) and 'mma' in str(dst_layout): @@ -3648,43 +3650,43 @@ def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): """ conversion = f""" - %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> - %13 = triton_gpu.convert_layout %11 : 
(tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> + %12 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm> - %16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src> - %17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm> - %18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src> + %15 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #interm> + %16 = triton_gpu.convert_layout %15 : (tensor<{M}x{N}xi32, #interm>) -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #interm> + %18 = triton_gpu.convert_layout %17 : (tensor<{M}x{N}xf16, #interm>) -> tensor<{M}x{N}xf16, #src> - %12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> - %13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> + %12 = triton_gpu.convert_layout %16 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst> """ - ir = layouts + """ - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<128> : tensor<128x1xi32, #src> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> - %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>> - %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #src> - %4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src> - %5 = arith.muli %4, %cst : tensor<128x1xi32, #src> - %6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src> - %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src> - %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src> - %9 = arith.addi %8, %7 : tensor<128x128xi32, #src> - %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr, #src>, tensor<128x128xi32, #src> - %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src> - %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #dst> - """ + conversion + """ - %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr, #dst>, tensor<128x128xi32, #dst> - tt.store %14, %13 : tensor<128x128xf16, #dst> + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range 
{{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src> + %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}xf16, #dst> tt.return - } -} + }} +}} """ - x = to_triton(numpy_random(shape, dtype_str=dtype), device=device) + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) # write the IR to a temporary file using mkstemp From d3956a21f3c9092ea756b818f2dc5f3d93bcbf17 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 13 Sep 2023 10:03:29 -0700 Subject: [PATCH 036/122] [BACKEND] Add LLVM pre-processing pass to break struct types (#2285) Add infrastructure to be able to add and test custom LLVM passes in the backend. This will allow use to apply some low level optimizations and cleanup on LLVM IR. Add a first pass that breaks up phi of struct created by lowering to LLVM. Those can often pessimise the optimizer as it would block optimizations going through phi nodes. --- bin/CMakeLists.txt | 17 +++ bin/triton-llvm-opt.cpp | 114 +++++++++++++++++++++ lib/Target/LLVMIR/CMakeLists.txt | 1 + lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp | 60 +++++++++++ lib/Target/LLVMIR/LLVMIRTranslation.cpp | 97 +++++++++++++++++- lib/Target/LLVMIR/LLVMPasses.h | 16 +++ test/LLVMIR/break-phi-struct.ll | 33 ++++++ test/lit.cfg.py | 3 +- 8 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 bin/triton-llvm-opt.cpp create mode 100644 lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp create mode 100644 lib/Target/LLVMIR/LLVMPasses.h create mode 100644 test/LLVMIR/break-phi-struct.ll diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 2b2f6afeb5ce..9da8e5628667 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -79,3 +79,20 @@ llvm_update_compile_flags(triton-translate) MLIRROCDLToLLVMIRTranslation ) mlir_check_all_link_libraries(triton-translate) + +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) diff --git a/bin/triton-llvm-opt.cpp b/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000000..fe82a1dce28e --- /dev/null +++ b/bin/triton-llvm-opt.cpp @@ -0,0 +1,114 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. 
+#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. 
+ std::unique_ptr Out; + std::string OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + return 0; +} diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index fbaefe68375c..9c0a6c26eea9 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_translation_library(TritonLLVMIR LLVMIRTranslation.cpp LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp LINK_COMPONENTS Core diff --git a/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000000..44afcfd21109 --- /dev/null +++ b/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 6d64fcbb1a29..3acc6a92e09c 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -1,7 +1,8 @@ #include "triton/Target/LLVMIR/LLVMIRTranslation.h" - +#include "LLVMPasses.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/Passes.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -26,11 +27,20 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #include @@ -42,6 +52,89 @@ namespace fs = std::filesystem; +namespace { +using namespace llvm; + +static std::optional mapToLevel(unsigned optLevel, + unsigned sizeLevel) { + switch (optLevel) { + case 0: + return OptimizationLevel::O0; + + case 1: + return OptimizationLevel::O1; + + case 2: + switch (sizeLevel) { + case 0: + return OptimizationLevel::O2; + + case 1: + return OptimizationLevel::Os; + + case 2: + return OptimizationLevel::Oz; + } + break; + case 3: + return OptimizationLevel::O3; + } + return std::nullopt; +} + +// Create and return a lambda that uses LLVM pass manager builder to set up +// optimizations based on the given level. +static std::function +makeOptimizingPipeline(unsigned optLevel, unsigned sizeLevel, + TargetMachine *targetMachine) { + return [optLevel, sizeLevel, targetMachine](Module *m) -> Error { + std::optional ol = mapToLevel(optLevel, sizeLevel); + if (!ol) { + return make_error( + formatv("invalid optimization/size level {0}/{1}", optLevel, + sizeLevel) + .str(), + inconvertibleErrorCode()); + } + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. This + // cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also applies + // some scheduling that helps performance in some cases. We should work on + // using NVPTX target instead and address the performance regressions with + // some scheduling solution. 
+ tuningOptions.SLPVectorization = true; + + PassBuilder pb(targetMachine, tuningOptions); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make sure all + // the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.addPass(pb.buildPerModuleDefaultPipeline(*ol)); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + namespace mlir { namespace triton { @@ -308,7 +401,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, return nullptr; } - auto optPipeline = mlir::makeOptimizingTransformer( + auto optPipeline = makeOptimizingPipeline( /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); diff --git a/lib/Target/LLVMIR/LLVMPasses.h b/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000000..1dcdb2992c02 --- /dev/null +++ b/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/test/LLVMIR/break-phi-struct.ll b/test/LLVMIR/break-phi-struct.ll new file mode 100644 index 000000000000..b27c87588f63 --- /dev/null +++ b/test/LLVMIR/break-phi-struct.ll @@ -0,0 +1,33 @@ +; RUN: triton-llvm-opt -break-struct-phi-nodes %s | FileCheck %s + +; CHECK-LABEL: struct +define {i32, i32} @struct(i1 %c) { +; CHECK: br i1 %{{.*}}, label [[TRUE:%.*]], label [[FALSE:%.*]] + br i1 %c, label %true, label %false + +true: + %s.1 = insertvalue {i32, i32} undef, i32 20, 0 + %s.2 = insertvalue {i32, i32} %s.1, i32 200, 1 + +; CHECK-DAG: [[E0:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E1:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +false: + %s.3 = insertvalue {i32, i32} undef, i32 30, 0 + %s.4 = insertvalue {i32, i32} %s.3, i32 300, 1 +; CHECK-DAG: [[E2:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E3:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +exit: +; CHECK-DAG: [[PHI0:%.*]] = phi i32 [ [[E0]], [[TRUE]] ], [ [[E2]], [[FALSE]] ] +; CHECK-DAG: [[PHI1:%.*]] = phi i32 [ [[E1]], [[TRUE]] ], [ [[E3]], [[FALSE]] ] +; CHECK: [[S0:%.*]] = insertvalue { i32, i32 } undef, i32 [[PHI0]], 0 +; CHECK: [[S1:%.*]] = insertvalue { i32, i32 } [[S0]], i32 [[PHI1]], 1 +; CHECK: ret { i32, i32 } [[S1]] + %r = phi {i32, i32} [ %s.2, %true], [ %s.4, %false ] + ret {i32, i32} %r +} diff --git a/test/lit.cfg.py b/test/lit.cfg.py index db65d3e4f172..5ea9c458dcd0 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -19,7 +19,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. 
-config.suffixes = ['.mlir'] +config.suffixes = ['.mlir', '.ll'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -62,6 +62,7 @@ llvm_config.with_environment('PATH', d, append_path=True) tools = [ 'triton-opt', + 'triton-llvm-opt', ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), ] From 896ee611e0453b56d20aa950ca63117c066af33f Mon Sep 17 00:00:00 2001 From: Sergey Kozub <110990540+sergeykozub@users.noreply.github.com> Date: Wed, 13 Sep 2023 19:04:01 +0200 Subject: [PATCH 037/122] [NFC] Create explicit conversion pattern for ExternElementwiseOp in TT->TTGPU pass (#2284) This is needed for forward-compatibility with MLIR that now has "inherent" and "discardable" attributes (https://mlir.llvm.org/OpenMeetings/2023-02-09-Properties.pdf) and the ExternElementwiseOp attrs do not propagate with the current `addNamedAttrs` implementation. --- .../TritonToTritonGPUPass.cpp | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index b03d76ac4bdb..de5ad6947c53 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -600,6 +600,24 @@ struct TritonScanReturnPattern } }; +struct TritonExternElementwisePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, + typename triton::ExternElementwiseOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands(), op.getLibnameAttr(), + op.getLibpathAttr(), op.getSymbolAttr(), + op.getPureAttr()), + adaptor.getAttributes()); + return success(); + } +}; + struct TritonPrintPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -692,9 +710,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern, TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, - TritonGenericPattern, TritonPrintPattern, - TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern, - TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context); + TritonExternElementwisePattern, TritonPrintPattern, TritonAssertPattern, + TritonAtomicRMWPattern, TritonFuncOpPattern, TritonReturnOpPattern, + TritonCallOpPattern>(typeConverter, context); } // From cf7f8c5ea4b6f00c6a93ed41b44560542427db67 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 13 Sep 2023 10:04:36 -0700 Subject: [PATCH 038/122] [BACKEND] Optimization to sink broadcast ops (#2274) Try to move broadcast ops after arithmetic and convert ops in order to reduce the amount of work needed. 
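As a rough illustration of the work this saves (a plain C++ analogy with made-up
sizes, not the pattern rewrite itself): applying the elementwise op to the
pre-broadcast values touches M elements instead of M*N, and the broadcast
afterwards only replicates already-computed results.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int M = 4, N = 1024; // hypothetical tensor sizes
  std::vector<float> col(M, 2.0f);

  // Broadcast first, then apply exp: M*N transcendental ops.
  std::vector<float> before(M * N);
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      before[i * N + j] = std::exp(col[i]);

  // Apply exp first, then broadcast: only M transcendental ops.
  std::vector<float> after(M * N);
  for (int i = 0; i < M; ++i) {
    const float e = std::exp(col[i]);
    for (int j = 0; j < N; ++j)
      after[i * N + j] = e;
  }

  std::printf("results match: %d\n", before == after);
  return 0;
}

The pass itself (MoveBroadcastAfterElementwisePattern in ReorderBroadcast.cpp
below) performs the equivalent rewrite at the tt.broadcast / elementwise-op
level.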
--- .../Triton/Transforms/ReorderBroadcast.cpp | 17 +++++---- .../Transforms/RemoveLayoutConversions.cpp | 35 ++++++++++--------- test/Triton/reorder-broadcast.mlir | 15 ++++++++ test/TritonGPU/combine.mlir | 17 +++++++++ 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index 1930ab9f6950..6c0c9fcc9341 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -116,6 +116,7 @@ struct MoveBroadcastAfterElementwisePattern auto operands = op->getOperands(); bool seenBroadcast = false; + Type srcType; for (auto operand : operands) { auto definingOp = operand.getDefiningOp(); if (!definingOp) { @@ -123,11 +124,13 @@ struct MoveBroadcastAfterElementwisePattern } if (auto broadcastOp = llvm::dyn_cast(definingOp)) { - if (seenBroadcast) { - // Only support one broadcasted argument for now + if (!seenBroadcast) { + seenBroadcast = true; + srcType = broadcastOp.getSrc().getType(); + } else if (srcType != broadcastOp.getSrc().getType()) { + // If the broadcast have different types we cannot re-order. return mlir::failure(); } - seenBroadcast = true; } else if (!isSplat(definingOp)) { // Not splat or broadcast return mlir::failure(); @@ -149,8 +152,7 @@ struct MoveBroadcastAfterElementwisePattern } } - auto src = broadcastOp.getSrc(); - auto srcTy = src.getType().dyn_cast(); + auto srcTy = broadcastOp.getSrc().getType().dyn_cast(); auto srcShape = srcTy.getShape(); auto srcEncoding = srcTy.getEncoding(); @@ -158,8 +160,9 @@ struct MoveBroadcastAfterElementwisePattern llvm::SmallVector newOperands; for (auto operand : operands) { auto definingOp = operand.getDefiningOp(); - if (llvm::isa(definingOp)) { - newOperands.push_back(src); + if (auto broadcastSrcOp = + llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); continue; } auto elemTy = diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 0b0dd9442990..cbdb59c88d2b 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -915,7 +915,7 @@ static void backwardRematerialization(ConvertLayoutOp convertOp) { // For convert left we try to hoist them above type extension to reduce the cost // of the convert. -static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { +static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { // we don't want to rematerialize any conversion to/from shared if (triton::gpu::isSharedEncoding(convertOp.getResult()) || triton::gpu::isSharedEncoding(convertOp.getOperand())) @@ -926,25 +926,27 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { if (targetType.getEncoding().isa()) return; - auto isExtOp = [](Operation *op) { - return isa(op); + auto isExtOrBroadcastOp = [](Operation *op) { + return isa(op); }; // 1. Take a backward slice of all the tensor dependencies. 
SetVector slice; DenseMap layout; - LogicalResult result = getRematerializableSlice( - convertOp.getOperand(), targetType.getEncoding(), slice, layout, isExtOp); + LogicalResult result = + getRematerializableSlice(convertOp.getOperand(), targetType.getEncoding(), + slice, layout, isExtOrBroadcastOp); if (result.failed()) return; - Operation *extOp = nullptr; + Operation *extOrBroadcatOp = nullptr; unsigned sliceSize = slice.size(); for (unsigned i = 0; i < sliceSize; i++) { Value v = slice[i]; Operation *op = v.getDefiningOp(); if (!op) continue; - if (isExtOp(op)) { + if (isExtOrBroadcastOp(op)) { SetVector tempSlice; DenseMap tempLayout; LogicalResult result = getRematerializableSlice( @@ -958,24 +960,25 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { } // Only apply it if there is a single ext op otherwise we would have to // duplicate the convert. - if (extOp != nullptr) + if (extOrBroadcatOp != nullptr) return; - extOp = op; + extOrBroadcatOp = op; } } - if (extOp == nullptr) + if (extOrBroadcatOp == nullptr) return; // Move the convert before the ext op and rewrite the slice. - OpBuilder builder(extOp); - auto tensorType = extOp->getOperand(0).getType().cast(); + OpBuilder builder(extOrBroadcatOp); + auto tensorType = + extOrBroadcatOp->getOperand(0).getType().cast(); auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), - layout[extOp->getResult(0)]); + layout[extOrBroadcatOp->getResult(0)]); auto newConvertOp = builder.create( - convertOp.getLoc(), newType, extOp->getOperand(0)); + convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); IRMapping mapping; - mapping.map(extOp->getOperand(0), newConvertOp.getResult()); + mapping.map(extOrBroadcatOp->getOperand(0), newConvertOp.getResult()); // 3. Rewrite the slice. 
rewriteSlice(slice, layout, convertOp, mapping); } @@ -994,7 +997,7 @@ static void hoistConvert(ModuleOp module) { module.walk( [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { - hoistConvertOnTopOfExt(convertOp); + hoistConvertOnTopOfExtOrBroadcast(convertOp); } } diff --git a/test/Triton/reorder-broadcast.mlir b/test/Triton/reorder-broadcast.mlir index fbe44dace289..d5e054337a08 100644 --- a/test/Triton/reorder-broadcast.mlir +++ b/test/Triton/reorder-broadcast.mlir @@ -38,3 +38,18 @@ tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32> } + +// CHECK-LABEL: @test_broadcast_binary_op_pattern +tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tensor<128x1xf32>, %arg2: tensor<1x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + // CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[mul]] : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast1 = tt.broadcast %arg1 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %mul = arith.mulf %broadcast0, %broadcast1 : tensor<128x128xf32> + + // CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x128xf32> + %broadcast2 = tt.broadcast %arg2 : (tensor<1x128xf32>) -> tensor<128x128xf32> + %mul1 = arith.mulf %broadcast0, %broadcast2 : tensor<128x128xf32> + + tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32> +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index ec6925b373cb..c6be17f8eb04 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -3,6 +3,10 @@ #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#layout2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#layout3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> + + module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -105,6 +109,19 @@ tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tens tt.return %4 : tensor<1024xf32, #layout1> } +// Hoist the convert on top of broadcast to make it cheaper. 
+// CHECK-LABEL: hoist_above_broadcast +tt.func @hoist_above_broadcast(%arg0: tensor<1024x1xf32, #layout2>, %arg1: f32) -> tensor<1024x128xf32, #layout3> { +// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: tt.broadcast %[[CVT]] +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return + %0 = tt.broadcast %arg0 : (tensor<1024x1xf32, #layout2>) -> tensor<1024x128xf32, #layout2> + %1 = tt.splat %arg1 : (f32) -> tensor<1024x128xf32, #layout2> + %2 = arith.addf %0, %1 : tensor<1024x128xf32, #layout2> + %3 = triton_gpu.convert_layout %2 : (tensor<1024x128xf32, #layout2>) -> tensor<1024x128xf32, #layout3> + tt.return %3 : tensor<1024x128xf32, #layout3> +} // CHECK-LABEL: if From b63e8f87fc2a6d0772bbc7f18b107a554f058166 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 13 Sep 2023 10:05:47 -0700 Subject: [PATCH 039/122] [FRONTEND] Override prototype (#2214) Low tech but very useful way to override kernels on the fly. This can be use for debugging functionality or performance problems this lets user dump modify and feed back IR into the jit compiler. --- python/triton/compiler/compiler.py | 15 ++++++++++++- python/triton/runtime/cache.py | 35 +++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 51c00add462a..f5a5d941160b 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -20,7 +20,7 @@ # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources -from ..runtime.cache import get_cache_manager +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability, version_key) @@ -229,6 +229,9 @@ def make_hash(fn, arch, env_vars, **kwargs): key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) + ignore_version = kwargs.get('ignore_version', False) + if (ignore_version): + return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest() return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest() @@ -433,6 +436,11 @@ def compile(fn, **kwargs): # create cache manager fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs)) + # managers used to dump and override IR for debugging + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + fn_override_manager = get_override_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True)) + fn_dump_manager = get_dump_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True)) + # determine name and extension type of provided function if isinstance(fn, JITFunction): name, ext = fn.__name__, "ast" @@ -493,6 +501,11 @@ def compile(fn, **kwargs): else: metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) fn_cache_manager.put(next_module, ir_filename) + fn_dump_manager.put(next_module, ir_filename) + if (enable_override and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + next_module = parse(full_name) else: if ir_name 
== "amdgcn": extra_file_name = f"{name}.hsaco_path" diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index db8f6193e9ac..e4721cbe3b2e 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -10,6 +10,14 @@ def default_cache_dir(): return os.path.join(Path.home(), ".triton", "cache") +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + class CacheManager(ABC): def __init__(self, key): pass @@ -36,17 +44,26 @@ def put_group(self, filename: str, group: Dict[str, str]): class FileCacheManager(CacheManager): - def __init__(self, key): + def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None - # create cache directory if it doesn't exist - self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir() - if self.cache_dir: + if (dump): + self.cache_dir = default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) + elif (override): + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) else: - raise RuntimeError("Could not create or locate cache dir") + # create cache directory if it doesn't exist + self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") def _make_path(self, filename) -> str: return os.path.join(self.cache_dir, filename) @@ -131,3 +148,11 @@ def get_cache_manager(key) -> CacheManager: __cache_cls_nme = user_cache_manager return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) From c61d772eee24c4d48b61273d4f92a504c69c170c Mon Sep 17 00:00:00 2001 From: Khushi Agrawal Date: Thu, 14 Sep 2023 01:00:40 +0530 Subject: [PATCH 040/122] [DOCS] add missing docs (#2154) --- python/triton/language/core.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3b9205f9d02a..451aa64fd58a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -888,6 +888,12 @@ def broadcast_to(input, shape, _builder=None): @builtin def trans(input, _builder=None): + """ + Returns a transposed tensor. + + :param input: The input tensor. + :type input: + """ return semantic.trans(input, _builder) @@ -926,6 +932,15 @@ def view(input, shape, _builder=None): @builtin def reshape(input, shape, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: + :param shape: The new shape. + :type shape: Tuple[int] + """ shape = _shape_check_impl(shape) return semantic.reshape(input, shape, _builder) @@ -1224,6 +1239,14 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): + """ + Returns the most significant 32 bits of the product of x and y. 
+ + :param x: the input tensor + :type x: int32 + :param y: the input tensor + :type y: int32 + """ x = _to_tensor(x, _builder) y = _to_tensor(y, _builder) return semantic.umulhi(x, y, _builder) @@ -1231,6 +1254,15 @@ def umulhi(x, y, _builder=None): @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): + """ + Returns a floating-point resultant tensor of dividing x by y. + + :param x: the input numerator value. + :param y: the input denominator value. + :param ieee_rounding: To follow IEEE-754 floating point number + rounding mechanism + :type ieee_rounding: bool + """ ieee_rounding = _constexpr_to_value(ieee_rounding) return semantic.fdiv(x, y, ieee_rounding, _builder) From 38a2ecdccfb6a01f45b9aad0ee658f654ef07412 Mon Sep 17 00:00:00 2001 From: Bin Fan Date: Wed, 13 Sep 2023 12:52:09 -0700 Subject: [PATCH 041/122] [OPTIMIZER] Fix Shared layout in OptimizeDotOperands pass to generate correct swizzling code (#2180) fix bug #1937 Co-authored-by: Philippe Tillet --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 35 ++++++++++++++++--- .../Transforms/OptimizeDotOperands.cpp | 7 +++- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 4 ++- 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2e3797af940a..7cfe7f448d42 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false. "ArrayRef":$order, "CTALayoutAttr":$CTALayout, "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ auto mmaEnc = dotOpEnc.getParent().dyn_cast(); if(!mmaEnc) @@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false. // --- handle A operand --- if (opIdx == 0) { // compute swizzling for A operand - int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m - int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2]; + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == 1) ? k : m; + int mmaStride = (order[0] == 1) ? m : k; int maxPhase = mmaStride / perPhase; return get(context, vec, perPhase, maxPhase, order, CTALayout); } // --- handle B operand --- if (opIdx == 1) { - int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k - int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1]; + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == 1) ? n : k; + int mmaStride = (order[0] == 1) ? k : n; int maxPhase = mmaStride / perPhase; return get(context, vec, perPhase, maxPhase, order, CTALayout); } @@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false. 
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); }]>, + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + AttrBuilder<(ins "ArrayRef":$shape, "ArrayRef":$order, "CTALayoutAttr":$CTALayout, diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 14a050472800..5a1d93c2569a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -60,9 +60,14 @@ class ConvertTransConvert : public mlir::RewritePattern { // used here. For tests where numCTAs = 1, this is not a problem since all // CTALayouts are the same. auto newXOrder = triton::gpu::getOrder(argEncoding); + // set needTrans to true here. newXEncoding is computed based on argEncoding + // which is before the transpose. without needTrans we will compute vec and + // maxPhase based on incorrect m, n and k size of mma. the type inference of + // TransOp simply swap the order but doesn't fix the vec and maxPhase for + // the YType, hence it would causing incorrect swizzling code. auto newXEncoding = triton::gpu::SharedEncodingAttr::get( getContext(), ZEncoding, XType.getShape(), newXOrder, - XEncoding.getCTALayout(), XType.getElementType()); + XEncoding.getCTALayout(), XType.getElementType(), true); auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(), newXEncoding); if (XEncoding == newXEncoding) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index db5513d92cfb..13de5d266cb6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -652,10 +652,12 @@ void LoopPipeliner::createBufferTypes() { .getEncoding() .dyn_cast()) { // MMAv1 and MMAv2 + bool needTrans = dyn_cast_or_null( + cvt.getDefiningOp()->getOperand(0).getDefiningOp()); unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); sharedEnc = ttg::SharedEncodingAttr::get( ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth); + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); } else { // MMAv3 sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), From a301502d254a1effb9821494a3d399983d1490d9 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Wed, 13 Sep 2023 12:58:42 -0700 Subject: [PATCH 042/122] [BACKEND] Fixing assert in shared encoding swizzling addresses calculation (#2292) --- lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 78c4f92caa53..667f174f2270 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -339,8 +339,9 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Order auto inOrder = triton::gpu::getOrder(srcEncoding); auto outOrder = triton::gpu::getOrder(resSharedLayout); - assert(outVec * (maxPhase - 1) <= srcShape[outOrder[0]] && - "Swizzling would generate out of bounds memory accesses"); + assert(maxPhase == 1 || + outVec 
* maxPhase <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); // Tensor indices held by the current thread, as LLVM values auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false); // Swizzling with leading offsets (e.g. Hopper GMMA) From 36087a108fa8102bb6e6593fab84b8c948b62a97 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:21:01 -0700 Subject: [PATCH 043/122] [FRONTEND] Added SASS to asm dict (#2280) --- python/setup.py | 17 +++++------ python/test/unit/language/test_line_info.py | 22 +++++++------- python/triton/common/backend.py | 32 +++++++++++++++------ python/triton/compiler/compiler.py | 22 ++++---------- python/triton/tools/disasm.py | 22 ++++++++++++-- 5 files changed, 70 insertions(+), 45 deletions(-) diff --git a/python/setup.py b/python/setup.py index bcdc5faa3107..e2c8d9ff96c6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -125,15 +125,15 @@ def get_thirdparty_packages(triton_cache_path): # ---- package data --- -def download_and_copy_ptxas(): - +def download_and_copy(src_path, version, url_func): base_dir = os.path.dirname(__file__) - src_path = "bin/ptxas" - version = "12.1.105" + # src_path = "bin/ptxas" + # version = "12.1.105" arch = platform.machine() if arch == "x86_64": arch = "64" - url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" + url = url_func(arch, version) + # url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" dst_prefix = os.path.join(base_dir, "triton") dst_suffix = os.path.join("third_party", "cuda", src_path) dst_path = os.path.join(dst_prefix, dst_suffix) @@ -156,9 +156,9 @@ def download_and_copy_ptxas(): shutil.copy(src_path, dst_path) return dst_suffix - # ---- cmake extension ---- + def get_base_dir(): return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) @@ -280,8 +280,9 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy_ptxas() - +download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") +download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2") +download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2") setup( name="triton", diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 2823cf9299b2..fc73f2bf374a 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -6,6 +6,7 @@ import triton import triton.language as tl +from triton.common.backend import path_to_nvdisasm @triton.jit @@ -50,10 +51,11 @@ def kernel_multi_files(X, Y, BLOCK: tl.constexpr): def extract_file_lines(asm): + nvdisasm, _ = path_to_nvdisasm() fd, path = tempfile.mkstemp() with open(fd, 'wb') as cubin: cubin.write(asm) - asm = subprocess.check_output(["nvdisasm", "-g", path]).decode("utf-8") + asm = subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8") file_lines = [] lines = asm.splitlines() for line in lines: @@ -80,7 
+82,7 @@ def check_file_lines(file_lines, file_name, lineno): @pytest.mark.parametrize("func", func_types) def test_line_info(func: str): try: - subprocess.check_output(["nvdisasm", "-h"]) + _, _ = path_to_nvdisasm() except BaseException: pytest.skip("nvdisasm is not available") @@ -99,20 +101,20 @@ def test_line_info(func: str): file_lines = extract_file_lines(kernel_info.asm["cubin"]) if func == "single": - assert (check_file_lines(file_lines, "test_line_info.py", 15)) assert (check_file_lines(file_lines, "test_line_info.py", 16)) + assert (check_file_lines(file_lines, "test_line_info.py", 17)) elif func == "call": - assert (check_file_lines(file_lines, "test_line_info.py", 28)) - assert (check_file_lines(file_lines, "test_line_info.py", 21)) - assert (check_file_lines(file_lines, "test_line_info.py", 30)) + assert (check_file_lines(file_lines, "test_line_info.py", 29)) + assert (check_file_lines(file_lines, "test_line_info.py", 22)) + assert (check_file_lines(file_lines, "test_line_info.py", 31)) elif func == "call_noinline": - assert (check_file_lines(file_lines, "test_line_info.py", 42)) - assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 43)) assert (check_file_lines(file_lines, "test_line_info.py", 36)) assert (check_file_lines(file_lines, "test_line_info.py", 37)) + assert (check_file_lines(file_lines, "test_line_info.py", 38)) elif func == "multi_files": - assert (check_file_lines(file_lines, "test_line_info.py", 47)) - assert (check_file_lines(file_lines, "test_line_info.py", 49)) + assert (check_file_lines(file_lines, "test_line_info.py", 48)) + assert (check_file_lines(file_lines, "test_line_info.py", 50)) assert (check_file_lines(file_lines, "standard.py", 33)) assert (check_file_lines(file_lines, "standard.py", 34)) assert (check_file_lines(file_lines, "standard.py", 36)) diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index 5b60c1377caa..edbcdde12c66 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -101,20 +101,34 @@ def get_backend(device_type: str): return _backends[device_type] if device_type in _backends else None -@functools.lru_cache() -def path_to_ptxas(): +def _path_to_binary(binary: str): base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ os.environ.get("TRITON_PTXAS_PATH", ""), - os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas") + os.path.join(base_dir, "third_party", "cuda", "bin", binary) ] - for ptxas in paths: - ptxas_bin = ptxas.split(" ")[0] - if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin): - result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT) + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) if version is not None: - return ptxas, version.group(1) - raise RuntimeError("Cannot find ptxas") + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def path_to_ptxas(): + return _path_to_binary("ptxas") + + +@functools.lru_cache() +def path_to_cuobjdump(): + return _path_to_binary("cuobjdump") + + +@functools.lru_cache() +def path_to_nvdisasm(): + return _path_to_binary("nvdisasm") diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py 
index f5a5d941160b..211821bbb7ae 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -5,7 +5,6 @@ import json import os import re -import tempfile from collections import namedtuple from pathlib import Path from typing import Any @@ -24,7 +23,7 @@ from ..runtime.driver import driver from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability, version_key) -from ..tools.disasm import extract +from ..tools.disasm import get_sass from .code_generator import ast_to_ttir from .make_launcher import make_stub from .utils import (InfoFromBackendForTensorMap, TensorMapManager, @@ -500,7 +499,6 @@ def compile(fn, **kwargs): metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) else: metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - fn_cache_manager.put(next_module, ir_filename) fn_dump_manager.put(next_module, ir_filename) if (enable_override and fn_override_manager.has_file(ir_filename)): print(f"\nOverriding kernel with file {ir_filename}") @@ -517,6 +515,11 @@ def compile(fn, **kwargs): if ir_name == "cubin": asm[ir_name] = next_module + sass_ir = "sass" + sass_fname = f"{name}.{sass_ir}" + asm[sass_ir] = get_sass(next_module) + metadata_group[sass_fname] = fn_cache_manager.put(asm[sass_ir], sass_fname) + elif ir_name == "amdgcn": asm[ir_name] = str(next_module[0]) else: @@ -669,16 +672,3 @@ def runner(*args, stream=None): self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) return runner - - def get_sass(self, fun=None): - if 'sass' in self.asm: - return self.asm['sass'] - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 24a0787c5c16..032b726682f5 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -20,8 +20,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
+import os import re import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') @@ -60,11 +64,25 @@ def processSassLines(fline, sline, labels): return (f'{ctrl}', f'{asm}') +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm if fun is None: - sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) else: - sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) sass_lines = sass_str.splitlines() line_idx = 0 while line_idx < len(sass_lines): From 08c16589573621fcb8cd5a9c3b8a0537077f876d Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 14 Sep 2023 12:03:23 -0400 Subject: [PATCH 044/122] [FRONTEND] Accommodate new triton IR format (#2294) - Support memory space for pointers (e.g., `!tt.ptr`). - Support parsing function attribute, though not used yet. --- python/test/unit/language/test_core.py | 66 +++++++++++++------------- python/test/unit/tools/test_aot.py | 2 +- python/triton/compiler/compiler.py | 14 ++++-- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index df798d5ac6c2..5bb68ef1e4fb 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1779,28 +1779,28 @@ def test_scan_layouts(M, N, src_layout, axis, device): ir = f""" #blocked = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> - %3 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> - %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> %6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> - %7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> + %7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> %8 = tt.broadcast %6 : 
(tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked> %11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{ ^bb0(%arg2: i32, %arg3: i32): %16 = arith.addi %arg2, %arg3 : i32 tt.scan.return %16 : i32 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> - %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> - %14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> - %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked> tt.return }} @@ -1871,7 +1871,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, }} }} """ if reduce2d else f""" - %14 = tt.splat %arg2 : (!tt.ptr<{ty}>) -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %14 = tt.splat %arg2 : (!tt.ptr<{ty}, 1>) -> tensor<{rdims_2d}x!tt.ptr<{ty}, 1>, #blocked> %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> %16 = {GPU_DIALECT}.convert_layout %13 : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}x{ty}, #blocked> @@ -1885,18 +1885,18 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, #blocked = {blocked} #src = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> - %4 = tt.splat %arg0 : (!tt.ptr<{ty}>) -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> - %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : (!tt.ptr<{ty}, 1>) -> tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>, 
tensor<{M}x1xi32, #blocked> %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> - %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked> %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked> %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x{ty}, #blocked> %12 = {GPU_DIALECT}.convert_layout %11 : (tensor<{M}x{N}x{ty}, #blocked>) -> tensor<{M}x{N}x{ty}, #src> %13 = "tt.reduce"(%12) ({{ @@ -1945,16 +1945,16 @@ def test_store_op(M, src_layout, device): ir = f""" #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> - %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> - %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src> %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> - %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> - %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> tt.store %8, %4 : tensor<{M}x1xf32, #src> tt.return }} @@ -1998,14 +1998,14 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): #dst = {dst_layout} #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %0 = tt.splat %arg0 : (!tt.ptr) -> 
tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> - %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> - %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> - %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> %7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.return @@ -2069,7 +2069,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): ir = f""" #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ - tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> @@ -2079,8 +2079,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> - %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> - %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src> %11 = "tt.reduce"(%10) ({{ 
^bb0(%arg2: i32, %arg3: i32): @@ -3664,22 +3664,22 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): ir = layouts + f""" module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> - %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> %6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> %7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> - %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src> - %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #dst> + %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #dst> """ + conversion + f""" - %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> tt.store %14, %13 : tensor<{M}x{N}xf16, #dst> tt.return }} diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 2c6fe88a2615..06a7ed2b12a1 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -330,7 +330,7 @@ def test_compile_link_autotune_matmul(): def test_ttgir_to_ptx(): src = """ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { - tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { tt.return } } diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 211821bbb7ae..500754d378ee 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -204,7 +204,9 @@ def get_kernel_name(src: str, pattern: str) -> str: def convert_type_repr(x): - match = re.search(r'!tt\.ptr<(.*)>', x) + # Currently we only capture the pointer type and assume the pointer is on global memory. 
+ # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) if match is not None: return '*' + convert_type_repr(match.group(1)) return x @@ -241,7 +243,8 @@ def make_hash(fn, arch, env_vars, **kwargs): # (letters, digits, or underscores), and capture it as group 1 (the function name) # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) -mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { "ttir": mlir_prototype_pattern, @@ -249,7 +252,11 @@ def make_hash(fn, arch, env_vars, **kwargs): "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?' +# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either: +# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol. +# |: OR +# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between. +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { "ttir": mlir_arg_type_pattern, @@ -422,6 +429,7 @@ def compile(fn, **kwargs): src = Path(fn).read_text() import re match = re.search(prototype_pattern[ir_name], src, re.MULTILINE) + # TODO: support function attributes at group 3 (e.g., device function) name, signature = match.group(1), match.group(2) types = re.findall(arg_type_pattern[ir_name], signature) if ir_name == 'ttgir': From 976aabdeb28eeb7db8bb48892899e7d31d69ca4c Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 15 Sep 2023 10:00:58 -0700 Subject: [PATCH 045/122] [BUILD] Fix few dependencies and layering issues to make lld work (#2307) This fixes few problems that were preventing me to use lld linker. --- .../Dialect/TritonGPU/Transforms/Utility.h | 7 ---- .../Dialect/TritonNvidiaGPU/IR/Dialect.h | 11 ++++++ lib/Analysis/CMakeLists.txt | 1 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 35 +++++++++++-------- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 24 ------------- lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp | 24 +++++++++++++ 6 files changed, 56 insertions(+), 46 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 6c0193182336..fe9f9f8c5953 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -141,13 +141,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, ArrayRef shape); -// Returns null if the op is not inside a agent region (warp specialization -// mode). Note that there should be at most one agent id attached to the -// operation. 
-std::optional getWSAgentId(Operation *op); -std::optional getWSRoleId(Operation *op); -void setRoleId(Operation *op, int roleId); - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index d07f0743615d..fc8a99457257 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -42,4 +42,15 @@ #define GET_OP_CLASSES #include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc" +namespace mlir { + +// Returns null if the op is not inside a agent region (warp specialization +// mode). Note that there should be at most one agent id attached to the +// operation. +std::optional getWSAgentId(Operation *op); +std::optional getWSRoleId(Operation *op); +void setRoleId(Operation *op, int roleId); + +} // namespace mlir + #endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index df1fe4066188..aecc2345ac1d 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(TritonAnalysis TritonGPUAttrDefsIncGen LINK_LIBS PUBLIC + ASMBuilder MLIRAnalysis MLIRLLVMDialect TritonIR diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 2b5798cff2ea..5503a07569a8 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,3 +1,10 @@ +# Separate out PTX/GCN builders to avoid cyclic dependencies as TritonAnalysis +# depends on it. +set(LLVM_OPTIONAL_SOURCES + GCNAsmFormat.cpp + PTXAsmFormat.cpp + ) + add_mlir_conversion_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp @@ -12,22 +19,7 @@ add_mlir_conversion_library(TritonGPUToLLVM LoadStoreOpToLLVM.cpp BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp - GCNAsmFormat.cpp - PTXAsmFormat.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp - ConvertLayoutOpToLLVM.cpp - DotOpToLLVM/FMA.cpp - DotOpToLLVM/MMAv1.cpp - DotOpToLLVM/MMAv2.cpp - DotOpToLLVM.cpp - ElementwiseOpToLLVM.cpp - LoadStoreOpToLLVM.cpp - TritonGPUToLLVM.cpp TritonGPUToLLVMPass.cpp - GCNAsmFormat.cpp - PTXAsmFormat.cpp ReduceOpToLLVM.cpp ScanOpToLLVM.cpp TypeConverter.cpp @@ -48,6 +40,7 @@ add_mlir_conversion_library(TritonGPUToLLVM Core LINK_LIBS PUBLIC + ASMBuilder MLIRIR MLIRPass MLIRGPUOps @@ -61,3 +54,15 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonNvidiaGPUTransforms NVGPUIR ) + +add_mlir_library(ASMBuilder + GCNAsmFormat.cpp + PTXAsmFormat.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect +) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 6e92fb2901ae..b4a5bbe920de 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -492,30 +492,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, return linear; } -std::optional getWSAgentId(Operation *op) { - int prevAgentId = -1; - if (auto attr = op->getAttrOfType("async_agent")) { - for (auto agentId : attr.getValues()) { - assert(prevAgentId == -1 && "support at most one agent id"); - prevAgentId = agentId; - } - } - if (prevAgentId == -1) - return std::nullopt; - return 
prevAgentId; -} - -std::optional getWSRoleId(Operation *op) { - if (!op->hasAttr("agent.mutex_role")) - return std::nullopt; - return op->getAttrOfType("agent.mutex_role").getInt(); -} - -void setRoleId(Operation *op, int roleId) { - auto attr = IntegerAttr::get(IntegerType::get(op->getContext(), 32), roleId); - op->setAttr("agent.mutex_role", attr); -} - namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index 0a982ce0572a..c7985a927b22 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -67,3 +67,27 @@ TritonNvidiaGPUDialect::verifyOperationAttribute(Operation *op, // TODO: fill this. return success(); } + +std::optional mlir::getWSAgentId(Operation *op) { + int prevAgentId = -1; + if (auto attr = op->getAttrOfType("async_agent")) { + for (auto agentId : attr.getValues()) { + assert(prevAgentId == -1 && "support at most one agent id"); + prevAgentId = agentId; + } + } + if (prevAgentId == -1) + return std::nullopt; + return prevAgentId; +} + +std::optional mlir::getWSRoleId(Operation *op) { + if (!op->hasAttr("agent.mutex_role")) + return std::nullopt; + return op->getAttrOfType("agent.mutex_role").getInt(); +} + +void mlir::setRoleId(Operation *op, int roleId) { + auto attr = IntegerAttr::get(IntegerType::get(op->getContext(), 32), roleId); + op->setAttr("agent.mutex_role", attr); +} From ac1c21611055991aebdfae66715da6f571388bb5 Mon Sep 17 00:00:00 2001 From: kshama-msft <66488860+kshama-msft@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:07:38 -0700 Subject: [PATCH 046/122] [DOCS] update README.md (#2311) Triton conf registration closed. --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 014ffba09413..ee38c164869e 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,8 @@ We're hiring! If you are interested in working on Triton at OpenAI, we have role ------------------- | [![Documentation](https://github.com/openai/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) -# Triton Developer Conference Registration Open -The Triton Developer Conference will be held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference will be held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. Please use the link below to register to attend either in-person or virtually online. - -Registration Link for Triton Developer Conference is [here](https://forms.office.com/r/m4jQXShDts) +# Triton Developer Conference Registration Now Closed +The Triton Developer Conference will be held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference will be held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. 
Tentative Agenda for the conference (subject to change): From db5c793f824eb120261a0e6dd921966033525bc1 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:31:43 -0700 Subject: [PATCH 047/122] [FRONTEND] Add sass to asm dict with lazy evaluation (#2309) --- python/triton/compiler/compiler.py | 16 ++++++++++------ python/triton/tools/disasm.py | 2 ++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 500754d378ee..bfffb9e584af 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -30,6 +30,14 @@ get_ids_of_tensormaps, parse_tma_info) +class LazyDict(dict): + def __getitem__(self, key): + val = dict.__getitem__(self, key) + if callable(val): + return val() + return val + + def inline_triton_ir(mod): pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -489,7 +497,7 @@ def compile(fn, **kwargs): metadata["device_type"] = device_type first_stage = list(stages.keys()).index(ext) - asm = dict() + asm = LazyDict() module = fn # run compilation pipeline and populate metadata for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]: @@ -523,11 +531,7 @@ def compile(fn, **kwargs): if ir_name == "cubin": asm[ir_name] = next_module - sass_ir = "sass" - sass_fname = f"{name}.{sass_ir}" - asm[sass_ir] = get_sass(next_module) - metadata_group[sass_fname] = fn_cache_manager.put(asm[sass_ir], sass_fname) - + asm["sass"] = lambda: get_sass(next_module) elif ir_name == "amdgcn": asm[ir_name] = str(next_module[0]) else: diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 032b726682f5..1e309a2e4940 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import functools import os import re import subprocess @@ -64,6 +65,7 @@ def processSassLines(fline, sline, labels): return (f'{ctrl}', f'{asm}') +@functools.lru_cache() def get_sass(cubin_asm, fun=None): fd, path = tempfile.mkstemp() try: From 78a0b5dc2a917cb96de6d20dc8f43e4ea3da1160 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 15 Sep 2023 21:38:15 -0400 Subject: [PATCH 048/122] [CI] update integration-tests.yml (#2310) --- .github/workflows/integration-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 097c55e590d1..74967580b4e3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -162,7 +162,6 @@ jobs: Integration-Tests-Third-Party: needs: Runner-Preparation - if: false runs-on: ${{ matrix.runner }} From 31b0c521427109a8eda609b58d756c380b21599a Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 15 Sep 2023 18:42:54 -0700 Subject: [PATCH 049/122] [FRONTEND][BACKEND] Add flag to control accumulation for fp8 (#2300) Change the dot to allow taking an initial accumulator and add a flag that will allow the compiler to accumulate in a lower precision than the output type. On Hopper this flag is on by default which allows accumualting with lower precision. This only affect Hopper fp8 dot. 
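
A minimal usage sketch of the new frontend surface (the kernel name, block sizes and
launch values below are illustrative only and not part of this patch; exercising the
fp8 path end to end assumes an sm_90 device, since on other targets the frontend
forces the flag to 0):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def fp8_dot_kernel(a_ptr, b_ptr, c_ptr, K,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr, MAX_IMPRECISE: tl.constexpr):
        # Single-tile matmul: one program computes a BLOCK_M x BLOCK_N output,
        # assuming row-major A (BLOCK_M x K) and B (K x BLOCK_N).
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
        b_ptrs = b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :]
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            # New in this patch: pass the running accumulator in explicitly and
            # bound how many K elements may be accumulated at the MMA's lower
            # precision before the partial result is folded back into the fp32
            # accumulator.
            acc = tl.dot(a, b, acc=acc, max_num_imprecise_acc=MAX_IMPRECISE)
            a_ptrs += BLOCK_K
            b_ptrs += BLOCK_K * BLOCK_N
        c_ptrs = c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :]
        tl.store(c_ptrs, acc)

    # Launch sketch: reinterpret uint8 storage as fp8e5m2, as the new unit test
    # in this patch does. Values are kept below 124 to avoid e5m2 inf/nan bit
    # patterns; they are placeholders, not a real workload.
    M, N, K = 64, 64, 128
    a = torch.randint(0, 124, (M, K), dtype=torch.uint8, device='cuda')
    b = torch.randint(0, 124, (K, N), dtype=torch.uint8, device='cuda')
    c = torch.empty((M, N), dtype=torch.float32, device='cuda')
    fp8_dot_kernel[(1,)](triton.reinterpret(a, tl.float8e5),
                         triton.reinterpret(b, tl.float8e5),
                         c, K, BLOCK_M=M, BLOCK_N=N, BLOCK_K=32,
                         MAX_IMPRECISE=32)

Leaving max_num_imprecise_acc unset keeps the fast Hopper default; smaller values
make the backend insert extra fp32 adds between WGMMA calls, trading some
throughput for accumulation accuracy.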
--- include/triton/Dialect/NVGPU/IR/NVGPUOps.td | 4 +- include/triton/Dialect/Triton/IR/TritonOps.td | 7 +- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 6 +- lib/Analysis/Utility.cpp | 6 ++ .../NVGPUToLLVM/NVGPUToLLVMPass.cpp | 36 ++++---- .../TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 62 +++++++++++-- .../TritonToTritonGPUPass.cpp | 3 +- lib/Dialect/Triton/Transforms/Combine.cpp | 5 +- lib/Dialect/Triton/Transforms/Combine.td | 22 ++--- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 3 +- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 3 +- .../Transforms/RemoveLayoutConversions.cpp | 3 +- .../Transforms/FenceInsertion.cpp | 4 +- .../TritonNvidiaGPU/Transforms/WSPipeline.cpp | 2 +- python/src/triton.cc | 7 +- python/test/unit/language/test_core.py | 87 ++++++++++++++++++- python/test/unit/operators/test_matmul.py | 86 +++++++++--------- python/triton/language/core.py | 4 +- python/triton/language/semantic.py | 18 +++- python/triton/ops/matmul.py | 13 ++- test/Analysis/test-alias.mlir | 2 +- test/Analysis/test-allocation.mlir | 6 +- test/Analysis/test-membar.mlir | 2 +- test/Conversion/invalid.mlir | 6 +- test/Conversion/triton_ops.mlir | 8 +- test/Conversion/triton_to_tritongpu.mlir | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 14 +-- test/Conversion/tritongpu_to_llvm_hopper.mlir | 71 +++++++++++++++ test/NVGPU/test_wgmma.mlir | 13 +++ test/Triton/combine.mlir | 6 +- test/TritonGPU/combine.mlir | 6 +- test/TritonGPU/dot-operands.mlir | 12 +-- test/TritonGPU/loop-pipeline-hopper.mlir | 8 +- test/TritonGPU/loop-pipeline.mlir | 16 ++-- test/TritonGPU/materialize-load-store.mlir | 2 +- test/TritonGPU/matmul.mlir | 2 +- test/TritonGPU/prefetch.mlir | 2 +- test/TritonGPU/reorder-instructions.mlir | 4 +- test/TritonGPU/rewrite-tensor-pointer.mlir | 2 +- test/TritonGPU/wsdecomposing.mlir | 12 +-- test/TritonGPU/wsmaterialization.mlir | 4 +- test/TritonGPU/wsmutex.mlir | 2 +- test/TritonGPU/wspipeline.mlir | 2 +- .../ws-feasibility-checking.mlir | 22 ++--- 44 files changed, 430 insertions(+), 177 deletions(-) diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td index e8e2c91e63ac..896a27c17de1 100644 --- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td @@ -148,12 +148,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC, + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional:$opC, I32Attr:$m, I32Attr:$n, I32Attr:$k, WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); let results = (outs LLVM_AnyStruct:$res); - let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)"; + let assemblyFormat = "$opA `,` $opB (`,` $opC^)? 
attr-dict `:` functional-type(operands, $res)"; } def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> { diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 69cad2bcf2eb..575db87be809 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -394,7 +394,12 @@ def TT_DotOp : TT_Op<"dot", [Pure, $d = matrix_multiply($a, $b) + $c }]; - let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + let arguments = (ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + BoolAttr:$allowTF32, + I32Attr:$maxNumImpreciseAcc); let results = (outs TT_FpIntTensor:$d); diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index cdf1146900c8..7d8cc7b41eda 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -258,7 +258,11 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure, $d = matrix_multiply($a, $b) + $c }]; - let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + let arguments = (ins TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + BoolAttr:$allowTF32, + I32Attr:$maxNumImpreciseAcc); let results = (outs TT_FpIntTensor:$d); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index eff4eb527a0f..6b4141170042 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -379,6 +379,12 @@ bool supportMMA(triton::DotOp op, int version) { aElemTy.isF32()))) { return false; } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + op.getType().cast().getElementType().isF32()) { + return false; + } } if (aElemTy.isF32() && bElemTy.isF32()) { return op.getAllowTF32() && version >= 2; diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index fcce9884bc6e..d9f27700d5f1 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -708,13 +708,13 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // TODO (zahi): Return type must always be a struct for wgmma, currently // we rely on the size of output constraints vector to determine whether // the output is a struct or not. We should find a way to pass this info - auto opC = op.getOpC(); - auto typeC = opC.getType(); + auto resultType = op.getType(); - auto structTypeC = typeC.dyn_cast(); - uint32_t numCRegs = structTypeC.getBody().size(); - std::string c = structTypeC.getBody().front().isF32() ? "=f" : "=r"; - return std::vector(numCRegs, c); + auto outputStructType = resultType.dyn_cast(); + uint32_t numOutputRegs = outputStructType.getBody().size(); + std::string output = + outputStructType.getBody().front().isF32() ? "=f" : "=r"; + return std::vector(numOutputRegs, output); } OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const { @@ -727,7 +727,8 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { auto structTypeA = typeA.dyn_cast(); // TODO (zahi): is this the best way to tie inputs/outputs ? 
- operandsAndConstraints.push_back({opC, "0"}); + if (opC) + operandsAndConstraints.push_back({opC, "0"}); if (structTypeA) { operandsAndConstraints.push_back({opA, "f"}); @@ -744,7 +745,6 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { using namespace ttn; auto opA = op.getOpA(); auto opB = op.getOpB(); - auto opC = op.getOpC(); auto m = op.getM(); auto n = op.getN(); auto k = op.getK(); @@ -757,12 +757,12 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // Register checks auto typeA = opA.getType(); auto typeB = opB.getType(); - auto typeC = opC.getType(); + auto typeOutput = op.getType(); auto structTypeA = typeA.dyn_cast(); auto structTypeB = typeB.dyn_cast(); - auto structTypeC = typeC.dyn_cast(); + auto structTypeOutput = typeOutput.dyn_cast(); assert(!structTypeB && "Operand B can not be registers"); - assert(structTypeC && "Operand C must be registers"); + assert(structTypeOutput && "Output and C operand must be registers"); // Element type, MNK shape and transposing support check // Reference: @@ -804,18 +804,20 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // Operands uint32_t asmOpIdx = 0; + std::string args = ""; - // Operand C - uint32_t numCRegs = structTypeC.getBody().size(); + // Output and operand C + uint32_t numCRegs = structTypeOutput.getBody().size(); - std::string args = ""; args += "{"; for (uint32_t i = 0; i < numCRegs; ++i) { args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); } args += "}, "; - asmOpIdx += numCRegs; + if (op.getOpC()) + asmOpIdx += numCRegs; + // Operand A if (structTypeA) { uint32_t numARegs = m * k / 128; @@ -833,8 +835,8 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // Operand B (must be `desc`) args += "$" + std::to_string(asmOpIdx++) + ", "; - // `scale-d` is 1 by default - args += "1"; + // `scale-d` is 1 if we have a C operand. + args += op.getOpC() ? 
"1" : "0"; // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based // WGMMA diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 9f943a615fa1..e69e89b5030d 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -260,11 +260,30 @@ SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, return results; } +static bool isFP8(triton::nvgpu::WGMMAEltType eltType) { + return eltType == triton::nvgpu::WGMMAEltType::e5m2 || + eltType == triton::nvgpu::WGMMAEltType::e4m3; +} + +static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, + Value a, Value b) { + int numEl = a.getType().cast().getBody().size(); + Value newStruct = rewriter.create(loc, a.getType()); + for (int i = 0; i < numEl; ++i) { + Value lhs = rewriter.create(loc, a, i); + Value rhs = rewriter.create(loc, b, i); + Value add = rewriter.create(loc, lhs, rhs); + newStruct = rewriter.create(loc, newStruct, add, i); + } + return newStruct; +} + LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, - bool allowTF32, const SharedMemoryObject &smemObjA, + bool allowTF32, uint32_t maxNumImpreciseAcc, + const SharedMemoryObject &smemObjA, const SharedMemoryObject &smemObjB, bool sync, Value thread) { auto aTensorTy = a.getType().cast(); @@ -311,7 +330,10 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, if (numTMADescs == 0) rewriter.create(loc, 0); rewriter.create(loc); - + // WGMMA fp8 -> fp32 accumulates in lower precision than fp32. + bool needsPartialAccumulator = isFP8(eltTypeA) && + eltTypeC == triton::nvgpu::WGMMAEltType::f32 && + maxNumImpreciseAcc <= aTensorTy.getShape()[1]; SmallVector mmaResults; for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { @@ -323,13 +345,33 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy); + uint32_t numLowPrecisionAcc = 0; + Value partialAcc; for (int k = 0; k < numRepK; ++k) { auto a = aLoader.smemLoad(m, k); auto b = bLoader.smemLoad(n, k); ValueRange operands{a, b, d}; - d = rewriter.create(loc, accTy, a, b, d, M, N, - K, eltTypeC, eltTypeA, - eltTypeB, layoutA, layoutB); + numLowPrecisionAcc += K; + // If using native accumulation would cause use to do more low precion + // accumulation than allowed do a separate allocation. + bool requireAddAccumulator = + needsPartialAccumulator && + (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1); + Value mmaAcc = needsPartialAccumulator ? partialAcc : d; + mmaAcc = rewriter.create( + loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB, + layoutA, layoutB); + if (needsPartialAccumulator) + partialAcc = mmaAcc; + else + d = mmaAcc; + // If we need accumulate separately to have higer precision, insert + // adds. 
+ if (requireAddAccumulator) { + d = faddAccumulate(rewriter, loc, d, partialAcc); + numLowPrecisionAcc = 0; + partialAcc = Value(); + } } auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy); for (int i = 0; i < acc.size(); ++i) { @@ -398,8 +440,9 @@ LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, - op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA, - smemObjB, true, thread); + op.getD(), llA, llB, llC, op.getAllowTF32(), + op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, true, + thread); } LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, @@ -426,6 +469,7 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, - op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA, - smemObjB, false, thread); + op.getD(), llA, llB, llC, op.getAllowTF32(), + op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, false, + thread); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index de5ad6947c53..fffdb05559ee 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -342,7 +342,8 @@ struct TritonDotPattern : public OpConversionPattern { c = rewriter.create(c.getLoc(), retType, c); addNamedAttrs(rewriter.replaceOpWithNewOp( - op, retType, a, b, c, adaptor.getAllowTF32()), + op, retType, a, b, c, adaptor.getAllowTF32(), + adaptor.getMaxNumImpreciseAcc()), adaptor.getAttributes()); return success(); } diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 99a891b9d496..0ee9f96ebbe3 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -181,8 +181,9 @@ class CombineBroadcastMulReducePattern : public mlir::RewritePattern { op->getLoc(), newAccType, rewriter.create(op->getLoc(), rewriter.getF32FloatAttr(0))); - rewriter.replaceOpWithNewOp( - op, expandLhsOp.getOperand(), expandRhsOp.getOperand(), newAcc, true); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getOperand(), + expandRhsOp.getOperand(), newAcc, + true, 0); return mlir::success(); } }; diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 974e09924f10..39c4ad234d22 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -12,22 +12,24 @@ include "mlir/IR/PatternBase.td" // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) def CombineDotAddIPattern : Pat< - (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc)), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), [(Constraint> $c)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32), $fastmath), - (TT_DotOp $a, $b, $d, $allowTF32), - [(Constraint> $c)]>; + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, 
$maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc)]>; def CombineDotAddIRevPattern : Pat< - (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), [(Constraint> $c)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d, $fastmath), - (TT_DotOp $a, $b, $d, $allowTF32), - [(Constraint> $c)]>; + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc)]>; // TODO: this fails for addptr(addptr(ptr, i32), i64) // Commented out until fixed diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index b8060cd6ce0c..8c82098416a8 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -316,7 +316,8 @@ class BlockedToMMA : public mlir::RewritePattern { } // convert dot instruction auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, - newAcc, dotOp.getAllowTF32()); + newAcc, dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); rewriter.replaceOpWithNewOp(op, oldRetType, newDot.getResult()); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 13de5d266cb6..ef9f60e1b1e5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1640,7 +1640,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { auto dotOp = cast(dot.getDefiningOp()); builder.setInsertionPoint(dot.getDefiningOp()); auto dotAsync = builder.create( - loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32()); + loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); dot.replaceAllUsesWith(dotAsync.getResult()); updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1); dot.getDefiningOp()->erase(); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index cbdb59c88d2b..2d8ca362465a 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -117,7 +117,8 @@ class ConvertDotConvert : public mlir::RewritePattern { op->getLoc(), dotOp.getResult().getType(), _0f); auto newDot = rewriter.create( op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0), - dotOp.getOperand(1), _0, dotOp.getAllowTF32()); + dotOp.getOperand(1), _0, dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); auto newCvt = rewriter.create( op->getLoc(), dstTy, newDot.getResult()); rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getOperand()); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index 3175fbbfb018..eb5c0f2ffcd2 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -50,9 +50,9 @@ struct FenceInsertionPass .getEncoding() .dyn_cast(); auto isHopperEncoding = mmaEncoding && 
mmaEncoding.isHopper(); - if (isHopperEncoding && (isa(a.getDefiningOp()) && + if (isHopperEncoding && (a.getDefiningOp() && ttg::isSharedEncoding(a)) || - (isa(b.getDefiningOp()) && + (b.getDefiningOp() && ttg::isSharedEncoding(b))) { // TODO: check whether cluster fence is needed diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp index 373eac0e548b..b7488a8ba4ea 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -736,7 +736,7 @@ void buildAsyncComm(const DenseMap> &map, auto dotAsync = builder.createWithAgentIds( loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), - dotOp.getAllowTF32()); + dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); dot.replaceAllUsesWith(dotAsync.getResult()); builder.createWithAgentIds(loc, 1); diff --git a/python/src/triton.cc b/python/src/triton.cc index 6ac87d6c34fe..1bb6dae3cfd0 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1482,9 +1482,10 @@ void init_triton_ir(py::module &&m) { }) .def("create_dot", [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, - mlir::Value &c, bool allowTF32) -> mlir::Value { - return self.create(c.getType(), a, b, c, - allowTF32); + mlir::Value &c, bool allowTF32, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create( + c.getType(), a, b, c, allowTF32, maxNumImpreciseAcc); }) .def("create_exp", [](TritonOpBuilder &self, mlir::Value &val) -> mlir::Value { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5bb68ef1e4fb..b5da4c17a072 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -131,8 +131,8 @@ def check_type_supported(dtype, device): cc = torch.cuda.get_device_capability() if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") - if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4"): - pytest.skip("float8e4 is only supported on NVGPU with cc >= 90") + if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"): + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") class MmaLayout: @@ -3750,3 +3750,86 @@ def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.co buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) assert buf14.to(torch.float32).mean() == -2.0 + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + low_precision_acc: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % 
num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_fp8_dot_acc(in_type_str, low_precision_acc, device): + check_type_supported(in_type_str, device) + M, N, K = 128, 256, 256 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + Bt = B.T + C = torch.empty((M, N), dtype=torch.float32, device='cuda') + num_warps = 8 + a = to_triton(A, device='cuda', dst_type=in_type_str) + b = to_triton(B, device='cuda', dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M), 1) + matmul_kernel[grid](a, b, C, M, N, K, + a.stride(0), a.stride(1), b.stride(0), b.stride( + 1), C.stride(0), C.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + torch_a = torch.from_numpy(A) + th_a = f8_to_f16(torch_a.cuda(), in_type_str) + torch_b = torch.from_numpy(B) + th_b = f8_to_f16(torch_b.cuda(), in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + elif low_precision_acc > 32: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(ref_out, C) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 19b5e0f050a2..642b0982b45a 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -26,61 +26,61 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM", itertools.chain( *[ [ # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 16, 1, 1, 2, None, 
None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True), - (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True), - (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True), + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, 
True] ], # n-stage *[ [ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True), + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] ], # mixed-precision *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"), ("float8e4nv", "float8e4nv"), ("float8e5", "float8e4nv"), @@ -91,14 +91,14 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): ("float16", "float32"), ("float32", "float16"), ("bfloat16", "float32"), - ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] + ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False] ], # mixed-precision block layout *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"), ("float16", "float8e5"), ("float16", "float32"), @@ -108,7 +108,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): ], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -176,7 +176,7 @@ def init_input(m, n, dtype): a = triton.reinterpret(a, getattr(tl, ADTYPE)) if b_fp8: b = triton.reinterpret(b, getattr(tl, BDTYPE)) - tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32) + tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM) torch.testing.assert_close(th_c, tt_c) except triton.OutOfResources as e: pytest.skip(str(e)) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 451aa64fd58a..a0fcb1633b5e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -985,7 +985,7 @@ def expand_dims(input, axis, _builder=None): @builtin -def dot(input, other, allow_tf32=True, out_dtype=float32, 
_builder=None): +def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out_dtype=float32, _builder=None): """ Returns the matrix product of two blocks. @@ -998,7 +998,7 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None): """ allow_tf32 = _constexpr_to_value(allow_tf32) out_dtype = _constexpr_to_value(out_dtype) - return semantic.dot(input, other, allow_tf32, out_dtype, _builder) + return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 49597e4cb7af..2b2b55136395 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1265,7 +1265,9 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: def dot(lhs: tl.tensor, rhs: tl.tensor, + acc: tl.tensor, allow_tf32: bool, + max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): @@ -1343,10 +1345,20 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) return cast(ret, ret_scalar_ty, builder) - - _0 = builder.create_splat(_0, [M, N]) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + if acc is None: + acc_handle = builder.create_splat(_0, [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if not (_is_cuda(builder.arch) and builder.arch == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()): + max_num_imprecise_acc = 0 + if max_num_imprecise_acc is None: + max_num_imprecise_acc = 2**30 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 35407b578bcd..4dc33e6b9ccf 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -82,6 +82,7 @@ def _kernel(A, B, C, M, N, K, stride_cm, stride_cn, dot_out_dtype: tl.constexpr, allow_tf32: tl.constexpr, + fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr ): @@ -118,7 +119,10 @@ def _kernel(A, B, C, M, N, K, if AB_DTYPE: a = a.to(C.dtype.element_ty) b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk acc = acc.to(C.dtype.element_ty) @@ -140,7 +144,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, dot_out_dtype, allow_tf32): + def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -182,12 +186,13 @@ def _call(a, b, dot_out_dtype, allow_tf32): c.stride(0), c.stride(1), dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod - def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True): - return 
_matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True): + return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum) matmul = _matmul.apply diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 75740d929da2..1f3b1df6541b 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -26,7 +26,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 919dca69201d..c0175811fff2 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -34,7 +34,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: offset = 0, size = 4224 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -64,11 +64,11 @@ tt.func @reusable(%A : !tt.ptr) { %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 4608 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> - %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> // CHECK-NEXT: offset = 0, size = 1152 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> - %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, 
maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return // CHECK-NEXT: size = 4608 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 63b4ef5d2ca1..961176cc68dc 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -32,7 +32,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/Conversion/invalid.mlir b/test/Conversion/invalid.mlir index 81b86650291f..178d5109f0f7 100644 --- a/test/Conversion/invalid.mlir +++ b/test/Conversion/invalid.mlir @@ -6,7 +6,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{element types of operands A and B must have same bit width}} - %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching encoding between A and B operands}} - %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching kWidth between A and B operands}} - %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index bb4cba09645d..ce6505d72e49 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -161,13 +161,13 @@ tt.func 
@dot_ops_infer(%ptr: !tt.ptr, %v : f32) { %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32> - %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> + %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> + %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32> - %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> + %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> %ptr128x128 = tt.splat %ptr : (!tt.ptr) -> tensor<128x128x!tt.ptr> %ptr32x32 = tt.splat %ptr : (!tt.ptr) -> tensor<32x32x!tt.ptr> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 185d88906d32..c08843e58540 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -5,7 +5,7 @@ tt.func @ops() { %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> - %0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index acd9d88ee45a..834c2eecbcc5 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -834,7 +834,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -967,7 +967,7 @@ module attributes 
{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<128x1x!tt.ptr, #blocked> @@ -993,7 +993,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<32x64xf16, #shared0>) -> tensor<32x64xf16, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x64x!tt.ptr, #blocked> @@ -1016,7 +1016,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x32x!tt.ptr, #blocked> tt.store %36, %28 : tensor<32x32xf32, #blocked> @@ -1053,7 +1053,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> 
tensor<32x1x!tt.ptr, #blocked> @@ -1265,7 +1265,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> - %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x32x!tt.ptr, #blocked> @@ -1295,7 +1295,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> - %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %1 = triton_gpu.convert_layout %0 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 053330d47a49..2095788623f4 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -78,3 +78,74 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return } } + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_high_precision_acc + tt.func @dot_high_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: 
llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + %m = triton_nvidia_gpu.dot_async %a, %b, %c + {maxNumImpreciseAcc = 32 : i32, allowTF32 = true} : + tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_low_precision_acc + tt.func @dot_low_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: llvm.return + %m = triton_nvidia_gpu.dot_async %a, %b, %c + {maxNumImpreciseAcc = 129 : i32, allowTF32 = true} : + tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_mix_precision_acc + tt.func @dot_mix_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: llvm.return + %m = triton_nvidia_gpu.dot_async %a, %b, %c + {maxNumImpreciseAcc = 64 : i32, allowTF32 = true} : + tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + tt.return + } +} diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir index f4ae65ad04cf..ee059b329fb1 100644 --- a/test/NVGPU/test_wgmma.mlir +++ b/test/NVGPU/test_wgmma.mlir @@ -17,3 +17,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 tt.return } } // end module + +// ----- + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @wgmma_no_acc(%descA: i64, %descB: i64) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %0, %1 : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %acc0 = nvgpu.wgmma %descA, %descB + {eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} : + (i64, i64) -> + !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + tt.return + } +} diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index deb85dc6222c..170bad012d23 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -10,12 +10,12 @@ tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128x %zero = arith.constant dense<0.0> : tensor<128x128xf32> %d = arith.constant dense<3.0> : tensor<128x128xf32> - %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> - // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res0 = arith.addf 
%dot_out, %d : tensor<128x128xf32> - // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res1 = arith.addf %d, %dot_out : tensor<128x128xf32> tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index c6be17f8eb04..5c3fdd6b9ba6 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1543,7 +1543,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %26 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> %27 = triton_gpu.convert_layout %25 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> %28 = triton_gpu.convert_layout %cst : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #blocked5> - %29 = tt.dot %26, %27, %28 {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> + %29 = tt.dot %26, %27, %28 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> %30 = triton_gpu.convert_layout %29 : (tensor<32x32xf32, #blocked5>) -> tensor<32x32xf32, #blocked> %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): @@ -1690,7 +1690,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %117 = tt.load %116 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3> %118 = triton_gpu.convert_layout %41 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %119 = triton_gpu.convert_layout %97 : (tensor<64x64xf16, #blocked6>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %120 = tt.dot %118, %119, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %120 = tt.dot %118, %119, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> %121 = triton_gpu.convert_layout %120 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #blocked2> %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ @@ -1719,7 +1719,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %142 = triton_gpu.convert_layout %141 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %143 = triton_gpu.convert_layout %117 : (tensor<64x64xf16, #blocked3>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> %144 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #blocked2>) -> 
tensor<128x64xf32, #blocked> - %145 = tt.dot %142, %143, %144 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> %146 = triton_gpu.convert_layout %145 : (tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked2> %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index ded8d0613bb9..039c9429438a 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -36,7 +36,7 @@ tt.func @push_elementwise( %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> %dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2k4> %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -58,7 +58,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout( %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil // CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * 
tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @push_convert_both_operands( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -93,7 +93,7 @@ tt.func @push_convert_both_operands( %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> %al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %bl = triton_gpu.convert_layout %be : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -119,7 +119,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil // CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @update_kwidth_slice( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -132,7 +132,7 @@ tt.func @update_kwidth_slice( %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> %al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %bl = triton_gpu.convert_layout %add : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, 
kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 7c54ce39b839..8bfb7b5760db 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -74,7 +74,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -151,7 +151,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -220,7 +220,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -293,7 +293,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // // %sa = triton_gpu.convert_layout %a : (tensor<128x32xf16, #BA>) -> tensor<128x32xf16, #SA> // %sb = triton_gpu.convert_layout %b : (tensor<32x128xf16, #BB>) -> tensor<32x128xf16, #SB> -// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = 
tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 626e7bdb1c85..a5bb2f239f64 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -84,7 +84,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -157,7 +157,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -224,7 +224,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -266,7 +266,7 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = 
tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr @@ -312,7 +312,7 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> @@ -362,7 +362,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %119 = tt.dot %117, %118, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %131 = arith.index_cast %arg9 : index to i32 %120 = arith.addi %131, %c1_i32 : i32 %121 = arith.muli %120, %c32_i32 : i32 @@ -425,7 +425,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %153 = tt.dot %151, %152, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %162 = arith.index_cast %arg9 : index to i32 %154 = arith.addi %162, %c2_i32 : i32 %155 = arith.muli %154, %c32_i32 : i32 @@ -497,7 +497,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %199 = tt.load %arg24, %198, %88 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : 
tensor<32x128xf16, #BL> %200 = triton_gpu.convert_layout %193 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> %201 = triton_gpu.convert_layout %199 : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> - %202 = tt.dot %200, %201, %arg23 {allowTF32 = true} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %202 = tt.dot %200, %201, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> } diff --git a/test/TritonGPU/materialize-load-store.mlir b/test/TritonGPU/materialize-load-store.mlir index 65ca0e6c65a7..58bc51514f36 100644 --- a/test/TritonGPU/materialize-load-store.mlir +++ b/test/TritonGPU/materialize-load-store.mlir @@ -52,7 +52,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %8 = tt.load %6 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x16xf16, #blockedB1> %9 = triton_gpu.convert_layout %7 : (tensor<64x16xf16, #blockedA1>) -> tensor<64x16xf16, #sharedA> %10 = triton_gpu.convert_layout %8 : (tensor<16x16xf16, #blockedB1>) -> tensor<16x16xf16, #sharedB> - %11 = tt.dot %9, %10, %cst {allowTF32 = true} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma> + %11 = tt.dot %9, %10, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma> %12 = triton_gpu.convert_layout %11 : (tensor<64x16xf32, #mma>) -> tensor<64x16xf32, #blockedA1> %13 = arith.truncf %12 : tensor<64x16xf32, #blockedA1> to tensor<64x16xf16, #blockedA1> %14 = arith.extsi %arg8 : i32 to i64 diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir index 81560db65943..6c9264400a30 100644 --- a/test/TritonGPU/matmul.mlir +++ b/test/TritonGPU/matmul.mlir @@ -62,7 +62,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1 %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) { %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32> %77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32> - %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> + %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> %79 = arith.addf %arg13, %78 : tensor<64x64xf32> %80 = arith.muli %arg7, %c64_i32 : i32 %81 = tt.splat %80 : (i32) -> tensor<64x64xi32> diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir 
index b820f4034abb..7104d8dc8e2e 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -53,7 +53,7 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> %b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP> - %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> + %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir index d7a7f3ae2fc0..a5b0dda944fc 100644 --- a/test/TritonGPU/reorder-instructions.mlir +++ b/test/TritonGPU/reorder-instructions.mlir @@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %9 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %10 = triton_gpu.convert_layout %9 : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared> %11 = triton_gpu.convert_layout %10 : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %11, %cst_0, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %12 = tt.dot %11, %cst_0, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked> tt.return @@ -41,7 +41,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %A = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %AS = triton_gpu.convert_layout %A : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared> %AD = triton_gpu.convert_layout %AS : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD, %BD, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %12 = tt.dot %AD, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked> 
tt.return diff --git a/test/TritonGPU/rewrite-tensor-pointer.mlir b/test/TritonGPU/rewrite-tensor-pointer.mlir index 23eddb24b536..cfe46c787fa8 100644 --- a/test/TritonGPU/rewrite-tensor-pointer.mlir +++ b/test/TritonGPU/rewrite-tensor-pointer.mlir @@ -46,7 +46,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %30 = triton_gpu.convert_layout %28 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> %31 = triton_gpu.convert_layout %29 : (tensor<64x128xf16, #blocked1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> %32 = triton_gpu.convert_layout %arg12 : (tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked2> - %33 = tt.dot %30, %31, %32 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2> + %33 = tt.dot %30, %31, %32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2> %34 = triton_gpu.convert_layout %33 : (tensor<128x128xf32, #blocked2>) -> tensor<128x128xf32, #blocked> // CHECK-NOT: tt.advance %35 = tt.advance %arg13, [%c0_i32, %c64_i32] : , 1> diff --git a/test/TritonGPU/wsdecomposing.mlir b/test/TritonGPU/wsdecomposing.mlir index 059554a59195..7d89baa8d711 100644 --- a/test/TritonGPU/wsdecomposing.mlir +++ b/test/TritonGPU/wsdecomposing.mlir @@ -97,7 +97,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> @@ -208,7 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: %90 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> // CHECK-NEXT: %91 = triton_gpu.convert_layout %89 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> // CHECK-NEXT: %92 = triton_gpu.convert_layout %90 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - // CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + // CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : 
vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> // CHECK-NEXT: %94 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> // CHECK-NEXT: %95 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> @@ -336,7 +336,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1> } - %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> @@ -452,7 +452,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: %96 = triton_gpu.convert_layout %94 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %95, %96 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1> // CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>} - // CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + // CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> // CHECK-NEXT: %91 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> // CHECK-NEXT: %92 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %90, %91, %92 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> @@ -587,7 +587,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * 
tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr, #blocked1>) { %r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32 %r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1> @@ -717,7 +717,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: %92 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> // CHECK-NEXT: %93 = triton_gpu.convert_layout %91 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> // CHECK-NEXT: %94 = triton_gpu.convert_layout %92 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - // CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + // CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> // CHECK-NEXT: %96 = scf.if %90 -> (tensor<128x32x!tt.ptr, #blocked1>) { // CHECK-NEXT: %99 = arith.select %90, %c31_i32, %c127_i32 {async_agent = dense<1> : vector<1xi32>} : i32 // CHECK-NEXT: %100 = tt.splat %99 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x32xi32, #blocked1> diff --git a/test/TritonGPU/wsmaterialization.mlir b/test/TritonGPU/wsmaterialization.mlir index 4ab8be6c5d96..07ee80f9b734 100644 --- a/test/TritonGPU/wsmaterialization.mlir +++ b/test/TritonGPU/wsmaterialization.mlir @@ -177,7 +177,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared> %64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1> %65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared> - %66 = tt.dot %64, %65, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma> + %66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma> %c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32 %c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32 %67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32 @@ -384,7 +384,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" %50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared> %51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> %52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> 
tensor<16x64xf16, #shared1> - %53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 %c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32 %54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 diff --git a/test/TritonGPU/wsmutex.mlir b/test/TritonGPU/wsmutex.mlir index 78b9037c51fa..1c0ad771218e 100644 --- a/test/TritonGPU/wsmutex.mlir +++ b/test/TritonGPU/wsmutex.mlir @@ -141,7 +141,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" %40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared> %41 = triton_gpu.extract_slice %1[%38, 0, 0] [1, 16, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> %42 = triton_gpu.convert_layout %41 {async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1> - %43 = tt.dot %40, %42, %arg12 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %43 = tt.dot %40, %42, %arg12 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> triton_nvidia_gpu.consumer_release %2, %38 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 %c1_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32 %44 = arith.addi %arg13, %c1_i32_5 {async_agent = dense<1> : vector<1xi32>} : i32 diff --git a/test/TritonGPU/wspipeline.mlir b/test/TritonGPU/wspipeline.mlir index a08e6fe1ef7c..a42ca46d009e 100644 --- a/test/TritonGPU/wspipeline.mlir +++ b/test/TritonGPU/wspipeline.mlir @@ -120,7 +120,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> diff --git a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir 
b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir index 0eec6889f8f7..981d4748d7cd 100644 --- a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir +++ b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir @@ -96,7 +96,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> @@ -226,7 +226,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1> } - %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> @@ -362,7 +362,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> - %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> %base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr, #blocked1>) { %r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32 %r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1> @@ -438,7 +438,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> - %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} 
: tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> @@ -518,7 +518,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> - %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> @@ -600,7 +600,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> - %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> @@ -686,7 +686,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> - %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> @@ -799,7 +799,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %44 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked4> %45 = triton_gpu.convert_layout %43 : (tensor<256x64xf16, #blocked3>) -> 
tensor<256x64xf16, #shared> %46 = triton_gpu.convert_layout %44 : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1> - %47 = tt.dot %45, %46, %arg15 {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> + %47 = tt.dot %45, %46, %arg15 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> %48 = tt.advance %arg16, [%c0_i32, %c64_i32] : , 1> %49 = tt.advance %arg17, [%c64_i32, %c0_i32] : , 1> scf.yield %47, %48, %49 : tensor<256x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> @@ -852,7 +852,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %b = tt.load %arg1 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked4> %shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared> %shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1> - %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> + %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> %out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2> tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2> } @@ -887,7 +887,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %b = tt.load %arg1 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr, #blocked4> -> tensor<64x128xf16, #blocked4> %shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared> %shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1> - %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> + %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> %out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2> tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2> } @@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %92 = tt.load %arg19 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked4> %93 = triton_gpu.convert_layout %91 : (tensor<64x16xf16, #blocked3>) -> tensor<64x16xf16, #shared> %94 = triton_gpu.convert_layout %92 : (tensor<16x64xf16, #blocked4>) -> tensor<16x64xf16, #shared1> - %95 = tt.dot %93, %94, %arg17 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %95 = tt.dot %93, %94, %arg17 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %96 = tt.advance %arg18, [%c0_i32, %c16_i32] : , 1> %97 = tt.advance %arg19, [%c16_i32, %c0_i32] : , 1> scf.yield %95, %96, %97 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> From bb949d1141d154ba2d493a0091e2187806b0cb7a Mon Sep 
17 00:00:00 2001 From: Thomas Raoux Date: Sat, 16 Sep 2023 12:28:53 -0700 Subject: [PATCH 050/122] [BACKEND] Move struct optimization down the LLVM pipeline (#2312) Move the optimization to remove phi of struct later in the optimization pipeline to avoid interfering with CFG optimization. --- lib/Target/LLVMIR/LLVMIRTranslation.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 3acc6a92e09c..cc07f0fe117a 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -122,12 +122,14 @@ makeOptimizingPipeline(unsigned optLevel, unsigned sizeLevel, pb.crossRegisterProxies(lam, fam, cgam, mam); ModulePassManager mpm; - llvm::FunctionPassManager fpm; - // Triton generates large structure of scalars which may pessimise - // optimizations, we run a pass to break up phi of struct to make sure all - // the struct are removed for the following passes. - fpm.addPass(BreakStructPhiNodesPass()); - mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make sure + // all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); mpm.addPass(pb.buildPerModuleDefaultPipeline(*ol)); mpm.run(*m, mam); return Error::success(); From 41584c71a64b5158d9dc358d055ca7c6a08f62cf Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sat, 16 Sep 2023 18:53:46 -0700 Subject: [PATCH 051/122] Add cuobjdump and nvsisasm to gitignore. (#2319) Otherwise, these files show up in `git status` under python/triton/third_party/cuda/bin/. --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index cd3d84ead5a9..fed9cbf4ea64 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,8 @@ venv.bak/ cmake-build-* # Third-party binaries +cuobjdump +nvdisasm ptxas # Docs From 0015611c17806ed7df9f238b414c3ca66f856e1b Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sun, 17 Sep 2023 00:14:33 -0700 Subject: [PATCH 052/122] [DOCS] Add build instrs for running in a virtualenv. (#2320) On my machine, when I try to `pip install cmake` outside a virtualenv, it gets mad at me and tells me to use apt. Which doesn't quite work for some reason. Anyway maybe this is simple to Python people, but perhaps worth mentioning. Especially because we have `.venv` in gitignore already. --- README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ee38c164869e..707c6d42f16a 100644 --- a/README.md +++ b/README.md @@ -62,12 +62,24 @@ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/ ``` git clone https://github.com/openai/triton.git; -cd triton/python; +cd triton; + pip install cmake; # build-time dependency -pip install -e . 
+pip install -e python ``` +Or with a virtualenv: + +``` +git clone https://github.com/openai/triton.git; +cd triton; +python -m venv .venv --prompt triton; +source .venv/bin/activate; + +pip install cmake; # build-time dependency +pip install -e python +``` # Changelog From c98671cf7c5d2ed03ea039f6acbc9486cd0bcd87 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 17 Sep 2023 01:16:00 -0700 Subject: [PATCH 053/122] Revert "Update integration-tests.yml" (#2323) reverts #2310 as recent changes to Triton-IR have broken third-party backends --- .github/workflows/integration-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 74967580b4e3..097c55e590d1 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -162,6 +162,7 @@ jobs: Integration-Tests-Third-Party: needs: Runner-Preparation + if: false runs-on: ${{ matrix.runner }} From 073aa16379bd7a832fe79dd126a6aaa786eef890 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sun, 17 Sep 2023 02:08:04 -0700 Subject: [PATCH 054/122] [BUILD] use ninja (#2318) --- .github/workflows/integration-tests.yml | 1 + README.md | 4 ++-- python/pyproject.toml | 2 +- python/setup.py | 6 +++--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 097c55e590d1..2ddb8ec82c91 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -74,6 +74,7 @@ jobs: cd python python3 -m pip install --upgrade pip python3 -m pip install cmake==3.24 + python3 -m pip install ninja python3 -m pip install --no-build-isolation -vvv '.[tests]' python3 -m pip install pytest-xdist diff --git a/README.md b/README.md index 707c6d42f16a..baec97a68e86 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/ git clone https://github.com/openai/triton.git; cd triton; -pip install cmake; # build-time dependency +pip install ninja cmake; # build-time dependencies pip install -e python ``` @@ -77,7 +77,7 @@ cd triton; python -m venv .venv --prompt triton; source .venv/bin/activate; -pip install cmake; # build-time dependency +pip install ninja cmake; # build-time dependencies pip install -e python ``` diff --git a/python/pyproject.toml b/python/pyproject.toml index 6430c0c154dc..8bd8093e720d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18"] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] [tool.autopep8] aggressive = 1 diff --git a/python/setup.py b/python/setup.py index e2c8d9ff96c6..9578249cb277 100644 --- a/python/setup.py +++ b/python/setup.py @@ -127,13 +127,10 @@ def get_thirdparty_packages(triton_cache_path): def download_and_copy(src_path, version, url_func): base_dir = os.path.dirname(__file__) - # src_path = "bin/ptxas" - # version = "12.1.105" arch = platform.machine() if arch == "x86_64": arch = "64" url = url_func(arch, version) - # url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" dst_prefix = os.path.join(base_dir, "triton") dst_suffix = os.path.join("third_party", "cuda", src_path) dst_path = os.path.join(dst_prefix, dst_suffix) @@ -219,6 +216,7 @@ def run(self): def build_extension(self, ext): lit_dir = shutil.which('lit') + ninja_dir = 
shutil.which('ninja') user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or \ os.getenv("HOMEPATH") or None if not user_home: @@ -233,6 +231,8 @@ def build_extension(self, ext): # python directories python_include_dir = sysconfig.get_path("platinclude") cmake_args = [ + "-G", "Ninja", # Ninja is much faster than make + "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, From 4f2d995fad5f02e115202873859fd96fe6096fca Mon Sep 17 00:00:00 2001 From: jon-chuang <9093549+jon-chuang@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:15:06 -0400 Subject: [PATCH 055/122] [FRONTEND] Explicitly forbid `dot(.., out_dtype=bfloat16)` (#2308) Fixes: https://github.com/openai/triton/issues/2302 --- python/triton/language/semantic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 2b2b55136395..019919629a40 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1311,6 +1311,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): assert lhs.shape[1].value >= 32, "small blocks not supported!" _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): _0 = builder.get_fp32(0) ret_scalar_ty = tl.float32 From 68e1bd162cc40ec36b066d790b45760a7438441c Mon Sep 17 00:00:00 2001 From: Stonepia Date: Mon, 18 Sep 2023 00:19:14 +0800 Subject: [PATCH 056/122] [FRONTEND] fix xpu stages logic (#2305) --- python/triton/compiler/compiler.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index bfffb9e584af..85c7b460a0ae 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -405,12 +405,6 @@ def compile(fn, **kwargs): add_cuda_stages(arch, extern_libs, stages) elif device_type == "hip": _device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) - elif device_type == "xpu": - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) - stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) - _device_backend.add_stages(arch, extern_libs, stages) else: # pass the user's configuration to the backend device. arch["num_warps"] = num_warps From 2b066000aa2af1fe5c2745d8b0ccee2d93106cdd Mon Sep 17 00:00:00 2001 From: Myeonghwan Ahn Date: Mon, 18 Sep 2023 01:41:02 +0900 Subject: [PATCH 057/122] [FRONTEND] fix matmul int8 overflow issue (#2297) Previously on matmul, if inputs are int8, output was also int8. This commit fixes the overflow problem with int32 output. 
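
As a rough sketch of the failure mode (illustrative arithmetic only, not code from this patch): even a short K dimension overflows an int8 accumulator, while int32 has plenty of headroom.

```
# Illustrative only: 64 products of 127 * 127 cannot fit in int8's range
# [-128, 127], but easily fit in int32's range [-2**31, 2**31 - 1].
k, v = 64, 127
acc = k * v * v                 # 1_032_256
print(acc > 127)                # True  -> an int8 output would overflow
print(-2**31 <= acc < 2**31)    # True  -> an int32 output is sufficient
```
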
#2296 --- python/triton/ops/matmul.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 4dc33e6b9ccf..9bbeb3650ac5 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -159,6 +159,8 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\ b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]: c_dtype = torch.float16 + elif a.dtype in [torch.int8] or b.dtype in [torch.int8]: + c_dtype = torch.int32 else: c_dtype = get_higher_dtype(a.dtype, b.dtype) c = torch.empty((M, N), device=device, dtype=c_dtype) @@ -178,6 +180,8 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): ab_dtype = True if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: ab_dtype = False + if a.dtype in [torch.int8] and b.dtype in [torch.int8]: + ab_dtype = False # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, From e686b4d6d48fe9249a107ea513ed1b4e5155c272 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 17 Sep 2023 14:58:50 -0700 Subject: [PATCH 058/122] [FRONTEND] interpreter rewrite (#2321) This is a new interpreter mode that shares semantic analysis with the JIT'ed codepath and that the Triton core team is committed to maintain --- .github/workflows/integration-tests.yml | 9 + python/setup.py | 5 +- python/src/triton.cc | 42 ++ .../test/unit/interpreter/test_interpreter.py | 69 -- python/test/unit/language/test_core.py | 1 - .../unit/operators/test_flash_attention.py | 11 +- python/triton/interpreter/__init__.py | 0 python/triton/interpreter/core.py | 9 - python/triton/interpreter/interpreter.py | 171 ----- python/triton/interpreter/memory_map.py | 102 --- python/triton/interpreter/tl_lang.py | 641 ------------------ python/triton/interpreter/torch_wrapper.py | 18 - python/triton/language/core.py | 14 +- python/triton/language/semantic.py | 2 + python/triton/language/standard.py | 2 +- python/triton/runtime/interpreter.py | 525 ++++++++++++++ python/triton/runtime/jit.py | 11 +- 17 files changed, 599 insertions(+), 1033 deletions(-) delete mode 100644 python/test/unit/interpreter/test_interpreter.py delete mode 100644 python/triton/interpreter/__init__.py delete mode 100644 python/triton/interpreter/core.py delete mode 100644 python/triton/interpreter/interpreter.py delete mode 100644 python/triton/interpreter/memory_map.py delete mode 100644 python/triton/interpreter/tl_lang.py delete mode 100644 python/triton/interpreter/torch_wrapper.py create mode 100644 python/triton/runtime/interpreter.py diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 2ddb8ec82c91..c5b1cf2101c7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -33,6 +33,7 @@ jobs: echo '::set-output name=matrix-optional::["ubuntu-latest"]' fi + Integration-Tests-Nvidia: needs: Runner-Preparation @@ -119,6 +120,14 @@ jobs: run: | rm -rf ~/.triton + - name: Run interpreter tests + env: + # TRITON_INTERPRET: "1" + CUA_VISIBLE_DEVICES: "" + run: | + cd python/test/unit + python3 -m pytest -vs operators/test_flash_attention.py + - name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} run: | diff --git a/python/setup.py b/python/setup.py index 
9578249cb277..0d7bf594aff3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -59,8 +59,8 @@ class Package(NamedTuple): def get_pybind11_package_info(): - name = "pybind11-2.10.0" - url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz" + name = "pybind11-2.11.1" + url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz" return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH") # llvm @@ -296,7 +296,6 @@ def build_extension(self, ext): "triton/_C", "triton/common", "triton/compiler", - "triton/interpreter", "triton/language", "triton/language/extra", "triton/ops", diff --git a/python/src/triton.cc b/python/src/triton.cc index 1bb6dae3cfd0..0068a23f8006 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -64,6 +64,7 @@ #include #include +#include namespace py = pybind11; PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy); @@ -1961,11 +1962,52 @@ void init_triton_translation(py::module &m) { }); } +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + m.def("load", + [](py::array_t ptrs, py::array_t masks, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptrs.size(); + auto shape = + std::vector(ptrs.shape(), ptrs.shape() + ptrs.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptrs = ptrs.reshape({numel}); + py::array_t reshaped_masks = masks.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptrs.size(); ++i) { + if (reshaped_masks.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptrs.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", [](py::array_t ptrs, py::array values, + py::array_t mask) { + int numel = ptrs.size(); + py::array_t reshaped_ptrs = ptrs.reshape({numel}); + py::array_t reshaped_masks = mask.reshape({numel}); + py::array reshaped_values = values.reshape({numel}); + for (size_t i = 0; i < ptrs.size(); ++i) { + if (reshaped_masks.at(i)) { + memcpy(reinterpret_cast(reshaped_ptrs.mutable_at(i)), + reshaped_values.data(i), values.dtype().itemsize()); + } + } + }); +} + void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); init_triton_env_vars(subm); // init_triton_codegen(subm.def_submodule("code_gen")); init_triton_runtime(subm.def_submodule("runtime")); init_triton_ir(subm.def_submodule("ir")); + init_triton_interpreter(subm.def_submodule("interpreter")); init_triton_translation(subm); } diff --git a/python/test/unit/interpreter/test_interpreter.py b/python/test/unit/interpreter/test_interpreter.py deleted file mode 100644 index b6bb6b79c206..000000000000 --- a/python/test/unit/interpreter/test_interpreter.py +++ /dev/null @@ -1,69 +0,0 @@ -import random - -import torch - -import triton -import triton.language as tl -from triton.interpreter.interpreter import program_ids_from_grid - - -def test_addition(): - - @triton.jit(interpret=True) - def add_kernel( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - a = torch.rand((128,), device="cuda") - b = 
torch.rand((128,), device="cuda") - expected = a + b - output = torch.empty((128,), device="cuda") - - def grid(meta): - return (triton.cdiv(128, meta["BLOCK_SIZE"]),) - - add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32) - - assert torch.allclose(expected, output, atol=1e-2, rtol=0) - - -def test_program_ids_from_grid(): - random.seed(123) - grid = (3, 4) - expected_combinations = 3 * 4 - unique_combinations = set(program_ids_from_grid(grid)) - assert len(unique_combinations) == expected_combinations - - first_run = list(program_ids_from_grid(grid)) - second_run = list(program_ids_from_grid(grid)) - assert first_run != second_run - - -def test_atomic(): - @triton.jit(interpret=True) - def atomic( - x_ptr, - ): - pid = tl.program_id(axis=0) - tl.atomic_add(x_ptr + pid, 1) - t = tl.atomic_xchg(x_ptr + pid, 3) - t += 1 # 2 - tl.atomic_cas(x_ptr + pid, 3, t) # match - tl.atomic_cas(x_ptr + pid, 40, 9) # no match - nb_dim = 16 - a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda") - - atomic[(nb_dim, )](a) - assert torch.allclose(a, torch.full_like(a, 2)) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b5da4c17a072..fe105e6ba398 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2421,7 +2421,6 @@ def kernel(X, stride_xm, stride_xk, if epilogue == 'chain-dot': z_ref = np.matmul(z_ref, w) # compare - # print(z_ref[:,0], z_tri[:,0]) if in_dtype == 'float32': # XXX: Somehow there's a larger difference when we use float32 np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index b6f74f2fc33d..75da98e5044a 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -5,10 +5,10 @@ import triton.ops -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16), - (4, 48, 1024, 32), - (4, 48, 1024, 64), - (4, 48, 1024, 128)]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128)]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) @@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): pytest.skip('Segmentation fault') capability = torch.cuda.get_device_capability() - if capability[0] < 8: + interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"] + if not interpreter and capability[0] < 8: pytest.skip("Flash attention only supported for compute capability < 80") torch.manual_seed(20) q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() diff --git a/python/triton/interpreter/__init__.py b/python/triton/interpreter/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/triton/interpreter/core.py b/python/triton/interpreter/core.py deleted file mode 100644 index 82f3f43a25a0..000000000000 --- a/python/triton/interpreter/core.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Tuple - -import dataclasses - - -@dataclasses.dataclass -class ExecutionContext: - program_id: Tuple[int] - program_size: Tuple[int] diff --git a/python/triton/interpreter/interpreter.py b/python/triton/interpreter/interpreter.py deleted file mode 100644 index 001b80ec9855..000000000000 --- 
a/python/triton/interpreter/interpreter.py +++ /dev/null @@ -1,171 +0,0 @@ -import itertools -import random -from typing import Tuple - -from .. import language as tl -# import .language.core as lcore -from ..language import core as lcore -from . import torch_wrapper -from .core import ExecutionContext -from .memory_map import MemoryMap -from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor, - debugger_constexpr) - -torch = torch_wrapper.torch -tl_method_backup = {} - - -def get_proxy_method(proxy, name): - method = getattr(proxy, name) - - def fun(*args, **kwarg): - return method(*args, **kwarg) - - return fun - - -def attach_triton(module, proxy): - method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"] - for name in method_list: - if hasattr(module, name): - attr = getattr(module, name) - tl_method_backup[name] = attr - if callable(attr): - setattr(module, name, get_proxy_method(proxy, name)) - else: - setattr(module, name, getattr(proxy, name)) - - -def detach_triton(module): - for name, method in tl_method_backup.items(): - setattr(module, name, method) - - -def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]: - # reverse the grid dimensions and generate the range for each dimension - reversed_grid = reversed(grid) - ranges_for_each_dimension = [range(dim) for dim in reversed_grid] - - # gen all combinations - index_combinations = list(itertools.product(*ranges_for_each_dimension)) - random.shuffle(index_combinations) - - for index_combination in index_combinations: - yield index_combination - - -class DebuggerFunction: - def __init__(self, func, grid=(1,)): - self.func = func - self.grid = grid - - def _is_constexpr(self, name): - return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr - - def _get_constexpr(self): - result = [] - for name, annotation in self.func.__annotations__.items(): - if annotation is lcore.constexpr: - result.append(name) - return result - - def _assert_constexpr(self, **kwargs): - constexp = self._get_constexpr() - missing = [i for i in constexp if i not in kwargs.keys()] - assert len(missing) == 0, f"You must specify constexpr {missing}" - - def _get_grid(self, **kwargs): - if callable(self.grid): - return self.grid(kwargs) - else: - return self.grid - - def __call__(self, *args, **kwargs): - self._assert_constexpr(**kwargs) - - memory = MemoryMap() - - def convert_arg(v): - name, arg = v - if torch.is_tensor(arg): - ptr = memory.add_tensor(arg) - return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda")) - if self._is_constexpr(name): - return debugger_constexpr(arg) - return WrappedTensor(_primitive_to_tensor(arg)) - - new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args))) - new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]} - - grid = self._get_grid(**kwargs) - for program_id in program_ids_from_grid(grid): - proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid)) - attach_triton(tl, proxy) - self.func(*new_args, **new_kwargs) - detach_triton(tl) - - -class GridSelector: - """ - Entry point of the debugger - """ - - def __init__(self, func): - version = torch.__version__ - assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}" - self.func = func - - def __getitem__(self, grid): - return DebuggerFunction(self.func, grid) - - def __call__(self, *args, **kwargs): - return DebuggerFunction(self.func)(*args, **kwargs) - - -class 
AutotuneGridSelector: - def __init__(self, func, autotune_params): - self.func = func - self.autotune_params = autotune_params - - def __getitem__(self, grid): - return AutotuneRunner(self.func, self.autotune_params, grid) - - def __call__(self, *args, **kwargs): - return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs) - - -class AutotuneRunner: - def __init__(self, func, autotune_params, grid=None): - self.func = func - self.autotune_params = autotune_params - self.grid = grid - - def __call__(self, *args, **kwargs): - assert len(self.autotune_params["configs"]) >= 1 - - for config in self.autotune_params["configs"][1:]: - - def convert_arg(v): - if torch.is_tensor(v): - return torch.clone(v) - return v - - new_args = tuple(map(convert_arg, args)) - new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} - if self.grid: - self.func[self.grid](*new_args, **new_kwargs, **config.kwargs) - else: - self.func(*new_args, **new_kwargs, **config.kwargs) - - main_config = self.autotune_params["configs"][0] - if self.grid: - self.func[self.grid](*args, **kwargs, **main_config.kwargs) - else: - self.func(*args, **kwargs, **main_config.kwargs) - - -def triton_debug_autotune(**kwars): - def wrapper(func): - return AutotuneGridSelector(func, kwars) - - return wrapper diff --git a/python/triton/interpreter/memory_map.py b/python/triton/interpreter/memory_map.py deleted file mode 100644 index d0ff732a74b9..000000000000 --- a/python/triton/interpreter/memory_map.py +++ /dev/null @@ -1,102 +0,0 @@ -from __future__ import annotations - -import dataclasses - -from . import torch_wrapper - -torch = torch_wrapper.torch - - -@dataclasses.dataclass -class RegisteredStorage: - storage: torch.Storage - dtype: torch.dtype - size: int - ptr: int - - @property - def end_ptr(self) -> int: - return self.ptr + self.size - - @property - def access_tensor(self) -> torch.Tensor: - return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device) - - def ensure_immutable(self): - assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size - - -class MemoryMap: - storages: [RegisteredStorage] - - def __init__(self): - self.storages = [] - - def _get_registered_storage(self, pointer: torch.Tensor): - max_pointer = torch.max(pointer).item() - min_pointer = torch.min(pointer).item() - - registered_storage = next( - filter( - lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages - ), - None, - ) - if registered_storage is None: - raise Exception("Storage not found or pointers spanning multiple tensors") - registered_storage.ensure_immutable() - return registered_storage - - def add_tensor(self, t: torch.Tensor): - storage = t.untyped_storage() - self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr())) - return t.data_ptr() - - def load( - self, - pointer: torch.Tensor, - mask: torch.Tensor = None, - other=0.0, - ): - assert pointer.is_cuda - assert 0 < pointer.dim() < 3 - assert pointer.dtype == torch.int64 - - if mask is None: - mask = torch.ones_like(pointer).bool() - assert mask.is_cuda - assert 0 < mask.dim() < 3 - assert mask.dtype == torch.bool - mask = mask.expand(pointer.size()) - - if torch.all(~mask): - # Todo: The type is wrong here, we can't determine the correct type - return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda") - - registered_storage = self._get_registered_storage(pointer[mask]) - access_tensor = registered_storage.access_tensor - - 
index_tensor = pointer - registered_storage.ptr - - block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda") - block[mask] = access_tensor[index_tensor[mask]] - return block - - def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): - assert 0 < pointer.dim() < 3 - assert pointer.dtype == torch.int64 - - if mask is None: - mask = torch.ones_like(pointer).bool() - assert 0 < mask.dim() < 3 - assert mask.dtype == torch.bool - mask = mask.expand(pointer.size()) - - if torch.all(~mask): - return - - registered_storage = self._get_registered_storage(pointer[mask]) - access_tensor = registered_storage.access_tensor - - index_tensor = pointer - registered_storage.ptr - access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype) diff --git a/python/triton/interpreter/tl_lang.py b/python/triton/interpreter/tl_lang.py deleted file mode 100644 index e2a578fa580f..000000000000 --- a/python/triton/interpreter/tl_lang.py +++ /dev/null @@ -1,641 +0,0 @@ -from __future__ import annotations - -from ..language import core as lcore -from . import torch_wrapper -from .core import ExecutionContext -from .memory_map import MemoryMap - -torch = torch_wrapper.torch - - -def _primitive_to_tensor(x): - """ - Converts various Python primitive data types to PyTorch tensor. - """ - tensor_args = {"device": "cuda"} - if isinstance(x, bool): - return torch.tensor([x], dtype=torch.bool, **tensor_args) - elif isinstance(x, int): - if -(2**31) <= x < 2**31: - return torch.tensor([x], dtype=torch.int32, **tensor_args) - elif -(2**63) <= x < 2**63: - return torch.tensor([x], dtype=torch.int64, **tensor_args) - else: - raise RuntimeError(f"Nonrepresentable integer {x}.") - elif isinstance(x, float): - return torch.tensor([x], dtype=torch.float32, **tensor_args) - elif torch.is_tensor(x): - return x - elif isinstance(x, WrappedTensor): - return x - elif isinstance(x, debugger_constexpr): - if x.value is None: - return None - return _primitive_to_tensor(x.value) - elif x is None: - return None - assert False, f"cannot convert {x} of type {type(x)} to tensor" - - -def _infer_tensor(func): - """ - A decorator function to harmonize function args: - - converts primitives to PyTorch tensors - - wraps PyTorch tensors with WrappedTensors - """ - def wrapper(*args): - new_args = tuple(map(lambda v: _primitive_to_tensor(v), args)) - new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args)) - - return func(*new_args) - - return wrapper - - -def _tensor_operation(func): - """ - A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function. - Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor). 
- """ - def wrapper(*args, **kwargs): - for arg in args: - assert not torch.is_tensor(arg), "unexpected tensor argument" - - def unwrap_tensor(v): - if isinstance(v, WrappedTensor): - return v.tensor - if isinstance(v, debugger_constexpr): - return v.value - return v - - new_args = tuple(map(unwrap_tensor, args)) - new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()} - - result = func(args[0], *new_args[1:], **new_kwargs) - return WrappedTensor(result) if torch.is_tensor(result) else result - - return wrapper - - -class debugger_constexpr: - def __init__(self, value): - if isinstance(value, debugger_constexpr): - self.value = value.value - else: - self.value = value - - def __str__(self) -> str: - return "debugger_constexpr(" + str(self.value) + ")" - - def __index__(self) -> int: - return self.value - - def __bool__(self): - return bool(self.value) - - def __ge__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value >= other - - def __gt__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value > other - - def __le__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value <= other - - def __lt__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value < other - - def __eq__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value == other - - def __or__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value | other - - def __ror__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value | other - - def __and__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value & other - - def __rand__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value & other - - def to(self, dtype, bitcast=False, _builder=None): - if dtype in [torch.int64]: - ret_ty = int - elif dtype == torch.bool: - ret_ty = bool - elif dtype in [torch.float64]: - ret_ty = float - else: - raise ValueError("dtype not supported in debugger") - return debugger_constexpr(ret_ty(self.value)) - - -class WrappedTensor: - def __init__(self, tensor): - self.tensor = tensor - - def __index__(self) -> int: - return self.tensor.item() - - def __str__(self) -> str: - return "wrapped_" + str(self.tensor) - - def __bool__(self) -> bool: - return torch.all(self.tensor == True).item() # noqa: E712 - - @property - def dtype(self): - return self.tensor.dtype - - @_infer_tensor - @_tensor_operation - def __add__(self, other): - return torch.add(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __radd__(self, other): - return self.__add__(other) - - @_infer_tensor - @_tensor_operation - def __sub__(self, other): - return torch.sub(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rsub__(self, other): - return torch.sub(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __mul__(self, other): - return torch.mul(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rmul__(self, other): - return self.__mul__(other) - - @_infer_tensor - @_tensor_operation - def __truediv__(self, other): - return torch.div(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rtruediv__(self, other): - return torch.div(other, 
self.tensor) - - @_infer_tensor - @_tensor_operation - def __floordiv__(self, other): - return torch.floor_divide(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rfloordiv__(self, other): - return torch.floor_divide(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __mod__(self, other): - return torch.remainder(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rmod__(self, other): - return torch.remainder(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __neg__(self): - return -self.tensor - - @_infer_tensor - @_tensor_operation - def __invert__(self): - return ~self.tensor - - @_infer_tensor - @_tensor_operation - def __and__(self, other): - return torch.bitwise_and(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __or__(self, other): - return torch.bitwise_or(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __xor__(self, other): - return torch.bitwise_xor(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __lshift__(self, other): - return torch.bitwise_left_shift(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rshift__(self, other): - return torch.bitwise_right_shift(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __gt__(self, other): - return self.tensor > other - - @_infer_tensor - @_tensor_operation - def __rgt__(self, other): - return other > self.tensor - - @_infer_tensor - @_tensor_operation - def __ge__(self, other): - return self.tensor >= other - - @_infer_tensor - @_tensor_operation - def __rge__(self, other): - return other >= self.tensor - - @_infer_tensor - @_tensor_operation - def __lt__(self, other): - return self.tensor < other - - @_infer_tensor - @_tensor_operation - def __rlt__(self, other): - return other < self.tensor - - @_infer_tensor - @_tensor_operation - def __le__(self, other): - return self.tensor <= other - - @_infer_tensor - @_tensor_operation - def __rle__(self, other): - return other <= self.tensor - - @_infer_tensor - @_tensor_operation - def __eq__(self, other): - return torch.equal(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __ne__(self, other): - return not torch.equal(self.tensor, other) - - @_tensor_operation - def __getitem__(self, slices): - return self.tensor.__getitem__(slices) - # if isinstance(slices, slice): - # slices = [slices] - # src_shape = self.shape - # dst_shape = [] - # curr = 0 - # for sl in slices: - # if isinstance(sl, constexpr) and sl.value is None: - # dst_shape.append(1) - # elif sl == slice(None, None, None): - # dst_shape.append(src_shape[curr].value) - # curr += 1 - # ret = torch.reshape(self.tensor, dst_shape, ) - # return ret - - @_tensor_operation - def to(self, dtype, bitcast=False): - return self.tensor.to(dtype) - # if isinstance(bitcast, constexpr): - # bitcast = bitcast.value - # if bitcast: - # return semantic.bitcast(self, dtype, ) - # return semantic.cast(self, dtype, ) - - -def _constexpr_to_value(v): - if isinstance(v, debugger_constexpr): - return v.value - return v - - -class TritonLangProxy: - _memory_map: MemoryMap - _context: ExecutionContext - - def __init__(self, memory_map: MemoryMap, context: ExecutionContext): - self._memory_map = memory_map - self._context = context - - # Types - # Removed void, int1, float8, uint16, uint32, uint64, pi32_t - - # constexpr = debugger_constexpr - - # Program functions - - @_tensor_operation - def load( - self, - pointer: torch.Tensor, - mask: torch.Tensor = None, - other=0.0, - 
cache_modifier="", - eviction_policy="", - volatile=False, - ): - return self._memory_map.load(pointer, mask, other) - - @_tensor_operation - def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): - return self._memory_map.store(pointer, value, mask) - - @_tensor_operation - def program_id(self, axis): - assert axis < len(self._context.program_id) - return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda") - - @_tensor_operation - def num_programs(self, axis): - assert axis < len(self._context.program_size) - return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda") - - @_tensor_operation - def arange(self, start, end): - return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda") - - @_tensor_operation - def zeros(self, shape, dtype): - for i, d in enumerate(shape): - if not isinstance(d, debugger_constexpr): - raise TypeError(f"Shape element {i} must have type `constexpr`") - if not isinstance(d.value, int): - raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") - shape = [x.value for x in shape] - if isinstance(dtype, lcore.dtype): - if dtype.is_fp32(): - dtype = torch.float32 - elif dtype.is_fp16(): - dtype = torch.float16 - elif dtype.is_bf16(): - dtype = torch.bfloat16 - elif dtype.is_int32(): - dtype = torch.int32 - elif dtype.is_int16(): - dtype = torch.int16 - elif dtype.is_int8(): - dtype = torch.int8 - else: - raise TypeError(f"Unsupported dtype {dtype}") - return torch.zeros(size=shape, dtype=dtype, device="cuda") - - @_tensor_operation - def dequantize(self, input, scale, shift, nbit, dst_ty=None): - if dst_ty is None: - dst_ty = torch.float16 - raise NotImplementedError() - - @_tensor_operation - def broadcast(self, input, other): - raise NotImplementedError() - - @_tensor_operation - def broadcast_to(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def cat(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def reshape(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True): - assert input.dtype == other.dtype - if trans_a: - input = input.T - if trans_b: - other = other.T - return torch.matmul(input=input, other=other) - - @_tensor_operation - def atomic_cas(self, pointer, cmp, val): - stored = self._memory_map.load(pointer, None, 0.0) - if not isinstance(cmp, torch.Tensor): - cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda") - if not isinstance(val, torch.Tensor): - val = torch.tensor([val], dtype=stored.dtype, device="cuda") - if stored == cmp: - self._memory_map.store(pointer, val, None) - return stored - - @_tensor_operation - def atomic_xchg(self, pointer, val, mask=None): - if isinstance(val, int): - val = torch.tensor([val], dtype=torch.int32, device="cuda") - stored = self._memory_map.load(pointer, mask, 0.0) - self._memory_map.store(pointer, val, mask) - return stored - - @_tensor_operation - def atomic_add(self, pointer, val, mask=None): - # arbitrary other value as it will masked during storing - stored = self._memory_map.load(pointer, mask, 0.0) - result = stored + val - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_max(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0.0) - result = torch.maximum(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - 
@_tensor_operation - def atomic_min(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0.0) - result = torch.minimum(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_and(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_and(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_or(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_or(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_xor(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_xor(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def where(self, condition, x, y): - condition = _primitive_to_tensor(condition) - x = _primitive_to_tensor(x) - y = _primitive_to_tensor(y) - return torch.where(condition, x, y) - - @_tensor_operation - def umulhi(self, x, y): - raise NotImplementedError() - - @_tensor_operation - def fdiv(self, x, y, ieee_rounding=False): - raise NotImplementedError() - - @_tensor_operation - def exp(self, x): - return torch.exp(x) - - @_tensor_operation - def log(self, x): - return torch.log(x) - - @_tensor_operation - def cos(self, x): - return torch.cos(x) - - @_tensor_operation - def sin(self, x): - return torch.sin(x) - - @_tensor_operation - def sqrt(self, x): - return torch.sqrt(x) - - @_tensor_operation - def globaltimer(self): - raise NotImplementedError() - - @_tensor_operation - def clock(self): - raise NotImplementedError() - - @_tensor_operation - def debug_barrier(self): - raise NotImplementedError() - - @_tensor_operation - def multiple_of(self, input, values): - return input - - @_tensor_operation - def max_contiguous(self, input, values): - return input - - @_tensor_operation - def max_constancy(self, input, values): - return input - - @_tensor_operation - def abs(self, x): - return torch.abs(x) - - @_tensor_operation - def cdiv(self, x, div): - return (x + div - 1) // div - - @_tensor_operation - def minimum(self, x, y): - if isinstance(x, int): - x = torch.tensor(x, device="cuda") - if isinstance(y, int): - y = torch.tensor(y, device="cuda") - return torch.minimum(x, y) - - @_tensor_operation - def maximum(self, x, y): - return torch.maximum(x, y) - - @_tensor_operation - def sigmoid(self, x): - raise NotImplementedError() - - @_tensor_operation - def softmax(self, x, ieee_rounding=False): - raise NotImplementedError() - - @_tensor_operation - def ravel(self, x): - raise NotImplementedError() - - @_tensor_operation - def swizzle2d(self, i, j, size_i, size_j, size_g): - raise NotImplementedError() - - @_tensor_operation - def zeros_like(self, input): - raise NotImplementedError() - - @_tensor_operation - def max(self, input, axis=None): - if axis is None: - return torch.max(input) - return torch.max(input, dim=axis).values - - @_tensor_operation - def argmax(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def min(self, input, axis=None): - if axis is None: - return torch.min(input) - return torch.min(input, dim=axis).values - - @_tensor_operation - def argmin(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def sum(self, input, axis=None): - if axis is None: - return torch.sum(input) - return torch.sum(input, dim=axis) - 
- @_tensor_operation - def xor_sum(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def cumsum(self, input, axis=None): - if axis is None: - return torch.cumsum(input) - return torch.cumsum(input, dim=axis) - - @_tensor_operation - def cumprod(self, input, axis=None): - if axis is None: - return torch.cumprod(input) - return torch.cumprod(input, dim=axis) diff --git a/python/triton/interpreter/torch_wrapper.py b/python/triton/interpreter/torch_wrapper.py deleted file mode 100644 index 44aa17eb1355..000000000000 --- a/python/triton/interpreter/torch_wrapper.py +++ /dev/null @@ -1,18 +0,0 @@ -try: - import torch as _torch -except ImportError: - _torch = None - - -class TorchWrapper: - """ - Helps in making torch an optional dependency - """ - - def __getattr__(self, name): - if _torch is None: - raise ImportError("Triton requires PyTorch to be installed") - return getattr(_torch, name) - - -torch = TorchWrapper() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index a0fcb1633b5e..150d3936018f 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -542,14 +542,15 @@ def __init__(self, handle, type: dtype): self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: - # ex. "float32[3,4]" - return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']' + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' @builtin def __add__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.add(self, other, _builder) + @builtin def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @@ -558,6 +559,7 @@ def __sub__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.sub(self, other, _builder) + @builtin def __rsub__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.sub(other, self, _builder) @@ -567,6 +569,7 @@ def __mul__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.mul(self, other, _builder) + @builtin def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @@ -575,6 +578,7 @@ def __truediv__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.truediv(self, other, _builder) + @builtin def __rtruediv__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.truediv(other, self, _builder) @@ -666,8 +670,6 @@ def __rrshift__(self, other, _builder=None): else: return semantic.lshr(other, self, _builder) - # comparison operators - # > @builtin def __gt__(self, other, _builder=None): @@ -745,7 +747,7 @@ def __getitem__(self, slices, _builder=None): slices = [slices] ret = self for dim, sl in enumerate(slices): - if isinstance(sl, constexpr) and sl.value is None: + if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: pass @@ -830,6 +832,8 @@ def arange(start, end, _builder=None): def _shape_check_impl(shape): shape = _constexpr_to_value(shape) for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) if not isinstance(d, constexpr): raise TypeError(f"Shape element {i} must have type `constexpr`") if not isinstance(d.value, int): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 019919629a40..7b9109b23b87 100644 --- 
a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1570,6 +1570,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) if isinstance(elem, tl.constexpr): return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value) elif isinstance(elem, tl.tensor): diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 8acc4261585f..8ef52cb9cfd6 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr else: return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast) else: - if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): if core.constexpr(input.dtype.is_floating()): input = input.to(core.float32) else: diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py new file mode 100644 index 000000000000..218208d86aa7 --- /dev/null +++ b/python/triton/runtime/interpreter.py @@ -0,0 +1,525 @@ +import inspect + +import numpy as np + +import triton +import triton.language as tl +from .._C.libtriton.triton import interpreter as _interpreter + + +# TODO: duplicate +def str_to_ty(name): + language = tl + if name[0] == "*": + ty = str_to_ty(name[1:]) + return language.pointer_type(ty) + tys = { + "fp8e4nv": language.float8e4nv, + "fp8e5": language.float8e5, + "fp8e4b15": language.float8e4b15, + "fp8e4b15x4": language.float8e4b15x4, + "fp16": language.float16, + "bf16": language.bfloat16, + "fp32": language.float32, + "fp64": language.float64, + "i1": language.int1, + "i8": language.int8, + "i16": language.int16, + "i32": language.int32, + "i64": language.int64, + "u8": language.uint8, + "u16": language.uint16, + "u32": language.uint32, + "u64": language.uint64, + "B": language.int1, + } + return tys[name] + + +class TensorHandle: + + def __init__(self, data, dtype): + self.data = data + self.dtype = dtype + + def __bool__(self): + return bool(self.data.all()) + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.dtype.element_ty + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype) + return ptrs, masks + + +def wrap_ret(compute_ret_ty): + def wrapper(fn): + def wrapped(*args, **kwargs): + ret = fn(*args, **kwargs) + return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs)) + return wrapped + return wrapper + + +class Builder: + + def __init__(self) -> None: + self.arch = None + # pass + + def set_grid_idx(self, x, y, z): 
+ assert x < self.grid_dim[0] + assert y < self.grid_dim[1] + assert z < self.grid_dim[2] + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + def np_dtype(self, tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + } + return np_types[tt_dtype] + + # constants + def get_half_ty(self): + return tl.float16 + + def get_float_ty(self): + return tl.float32 + + def get_int64_ty(self): + return tl.int64 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.tensor(shape, dtype) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + assert self.grid_idx is not None + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.dtype.element_ty + dtype_np = self.np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + if isinstance(dst_type, tl.tensor): + dst_type = dst_type.dtype + return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, 
dst_type): + assert "float8 not NotImplemented yet" + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide) + create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide) + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: 
self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_fabs = lambda self, arg: self.unary_op(arg, np.abs) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + + # tensor operators + create_dot = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.dot) + create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype) + create_trans = lambda self, arg: self.unary_op(arg, np.transpose) + + def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc): + return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.dtype.element_ty + return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + assert padding_option is None + other = None + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty) + # def create_cat(self, lhs, rhs): + # pass + + # def create_broadcast(self, arg, shape): + # pass + + def create_splat(self, arg, shape): + return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype) + + # def create_atomic_cas(self, ptr, cmp, val, sem): + # pass + + # def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem): + # pass + + # def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + # pass + + # def create_reduce(self, operands, axis): + # pass + + # def create_reduce_ret(self, args): + # pass + + # def create_scan(self, operands, axis): + # pass + + # def create_scan_ret(self, args): + # pass + + # def create_ptr_to_int(self, val, type): + # pass + + # def create_int_to_ptr(self, val, type): + # pass + + # def create_inline_asm(self, 
inlineAsm, constraints, values, type, isPure, pack): + # pass + + # def create_print(self, prefix, values): + # pass + + # def create_assert(self, condition, message, fileName, funcName, lineNo): + # pass + + # def create_undef(self, type): + # pass + + # def create_barrier(self): + # pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + return BlockPointerHandle(base, shape, strides, np.array(offsets), tensor_shape, order) + + def create_advance(self, ptr, offsets): + assert len(ptr.offsets) == len(offsets) + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, ptr.offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + +def patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_lang_tensor(tensor, builder): + for name, member in inspect.getmembers(tensor): + if tl.core.is_builtin(member): + patch_attr(tensor, name, member, builder) + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: True + + +def _patch_lang_core(lang, builder): + for name, member in inspect.getmembers(lang): + if tl.core.is_builtin(member): + patch_attr(lang, name, member, builder) + # reduce is better off with a separate patch due to how + # the builder currently interfaces with custom functions + + def _new_reduce(input, axis, combine_fn): + fn = combine_fn.fn.__name__ + mapping = { + 'maximum': np.max, + '_sum_combine': np.sum, + } + ret = mapping[fn](input.handle.data, axis=axis) + ret_type = tl.block_type(input.dtype, ret.shape) + return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type) + + lang.reduce = _new_reduce + + +def _patch_lang_math(lang, builder): + math = lang.math + mapping = { + 'abs': 'abs', + 'acos': 'arccos', + 'asin': 'arcsin', + 'exp2': 'exp2', + 'log2': 'log2', + 'max': 'maximum', + } + + def make_numpy(name): + def impl(*args, **kwargs): + ret_type = args[0].type # TODO: incorrect + ret_dtype = args[0].dtype # TODO: incorrect + args = [arg.handle.data for arg in args] + kwargs = {k: v.handle.data for k, v in kwargs.items()} + ret = getattr(np, mapping[name])(*args, **kwargs) + ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type) + return ret + return impl + + def make_fallback(name): + def fallback(*args, **kwargs): + raise NotImplementedError(f""" +{name} not supported in interpreter mode: no known numpy implementation. +If you think that {name} in fact does have a numpy implementation, please add it +to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math. 
+""") + return fallback + + for name, member in inspect.getmembers(math): + if name in mapping: + setattr(math, name, make_numpy(name)) + else: + setattr(math, name, make_fallback(name)) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg], dtype=np.int32), ty) + return tl.tensor(handle, ty) + if hasattr(arg, 'data_ptr'): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +def _unwrap(tensor): + if isinstance(tensor, triton.TensorWrapper): + return tensor.base + return tensor + + +builder = Builder() + +RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization'] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr'] + + def _patch_lang(self, builder): + lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_core(lang[0], builder) + _patch_lang_math(lang[0], builder) + + def __call__(self, *args_dev, **kwargs): + args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + # remaps core language functions to interpreted ones + self._patch_lang(builder) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3 + grid = grid + (1,) * (3 - len(grid)) + builder.set_grid_dim(*grid) + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + builder.set_grid_idx(x, y, z) + self.fn(**args) + # copy arguments back to propagate side-effects + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, 'data_ptr'): + _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device)) + + +class InterpretedFunction: + + def _patch_lang(self, builder): + lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_core(lang[0], builder) + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs['grid'] + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']} + + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def __getitem__(self, grid): + return 
GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + self._patch_lang(builder) + return self.fn(*args, **kwargs) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index f664b33c71ab..1809ce36cf50 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -14,6 +14,7 @@ from .._C.libtriton.triton import TMAInfos from ..common.backend import get_backend, path_to_ptxas from ..language.core import dtype +from .interpreter import InterpretedFunction TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TRITON_VERSION = "2.1.0" @@ -270,10 +271,6 @@ def _type_of(key): tys[v] = v return key if isinstance(key, str) else f"*{tys[dtype_str]}" - def _make_signature(self, sig_key): - signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) - return signature - def _make_constants(self, constexpr_key): constants = dict(zip(self.constexprs, constexpr_key)) return constants @@ -568,7 +565,6 @@ def jit( do_not_specialize: Optional[Iterable[int]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, - interpret: Optional[bool] = None, ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: """ Decorator for JIT-compiling a function using the Triton compiler. @@ -590,9 +586,8 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - if interpret: - from ..interpreter.interpreter import GridSelector - return GridSelector(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + return InterpretedFunction(fn) else: return JITFunction( fn, From 894fa9e9436bd268fe770ff383a5527368e03436 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 17 Sep 2023 16:49:30 -0700 Subject: [PATCH 059/122] [RUNTIME][INTERPRETER] now also override __str__ method for tensors (#2325) --- python/triton/runtime/interpreter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 218208d86aa7..e066ea082926 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -370,6 +370,8 @@ def _patch_lang_tensor(tensor, builder): patch_attr(tensor, name, member, builder) tensor.__index__ = lambda self: int(self.handle.data) tensor.__bool__ = lambda self: True + tensor.__str__ = lambda self: str(self.handle.data) + tensor.__getitem__ = lambda self, slices: self.handle.data.__getitem__(slices) def _patch_lang_core(lang, builder): From a9ae9886dc3dada09668ca44c53b05527fe5f701 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 18 Sep 2023 17:22:06 +0200 Subject: [PATCH 060/122] Integration fixes for llvm/llvm-project#66512 (#2328) Some duplicate functions on `scf.for` have been removed in llvm/llvm-project#66512. This PR works with and without llvm/llvm-project#66512. 
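
For reference, a minimal sketch (not part of the patch; the helper name is hypothetical) of the call-site migration applied below. It only relies on `scf::ForOp::getInitArgs()`, which is available on both sides of llvm/llvm-project#66512:

```cpp
// Hypothetical helper showing how former getIterOperands()/getNumIterOperands()
// call sites are rewritten in terms of getInitArgs().
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"

static llvm::SmallVector<mlir::Value> collectLoopInitArgs(mlir::scf::ForOp forOp) {
  // Previously: SmallVector<Value> args = forOp.getIterOperands();
  //             unsigned numIterArgs = forOp.getNumIterOperands();
  llvm::SmallVector<mlir::Value> args = llvm::to_vector(forOp.getInitArgs());
  size_t numIterArgs = forOp.getInitArgs().size(); // replaces getNumIterOperands()
  (void)numIterArgs; // illustration only
  return args;
}
```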
--- lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp | 10 +++++----- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 2 +- lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 2393599143be..c6ec9673cf1b 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -325,9 +325,9 @@ class RewriteTensorPointerPass Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, std::stack &eraser) { // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = op.getIterOperands(); - SmallVector newIterOperands = op.getIterOperands(); - for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; ++i, ++oldI) { if (!triton::isTensorPointerType(newIterOperands[i].getType())) continue; @@ -350,7 +350,7 @@ class RewriteTensorPointerPass // mapping. It may refer to a value in the old loop, but we will rewrite it // later IRMapping mapping; - for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); if (triton::isTensorPointerType(oldRegionIterArg.getType())) { @@ -377,7 +377,7 @@ class RewriteTensorPointerPass } // Replace later usages - assert(op.getNumResults() == op.getNumIterOperands()); + assert(op.getNumResults() == op.getInitArgs().size()); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); if (triton::isTensorPointerType(oldResult.getType())) { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index ef9f60e1b1e5..0a370100fc9a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1007,7 +1007,7 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { // We need this to update operands for yield // original block arg => new arg's idx SmallVector newLoopArgs; - for (auto v : forOp.getIterOperands()) + for (auto v : forOp.getInitArgs()) newLoopArgs.push_back(v); bufferIdx = newLoopArgs.size(); diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 07e982dbf65d..a597ada6ca49 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -269,7 +269,7 @@ scf::ForOp Prefetcher::createNewForOp() { OpBuilder builder(forOp); SmallVector loopArgs; - for (auto v : forOp.getIterOperands()) + for (auto v : forOp.getInitArgs()) loopArgs.push_back(v); for (Value dot : dots) { loopArgs.push_back( From 307b5caa491dc9c10274400fb400677cd98de4a1 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 18 Sep 2023 17:45:05 -0400 Subject: [PATCH 061/122] [BACKEND] Fix scan issues on repetitive warps and improve perf when there's a single warp on the axis (#2330) 1. On the axis, using `getAxisNumWarpsWithUniqueData` instead of getting the raw number of warps to avoid communication among warps that handle the same piece of data. 2. 
When there's a single warp on the axis, using warp intrinsics for communication and skipping shared memory. A follow-up PR will clean up the code.
---
 include/triton/Analysis/Utility.h            |   2 +
 lib/Analysis/Utility.cpp                     |  12 ++
 .../TritonGPUToLLVM/ScanOpToLLVM.cpp         | 161 +++++++++++++++---
 lib/Conversion/TritonGPUToLLVM/Utility.cpp   |  19 ++-
 lib/Conversion/TritonGPUToLLVM/Utility.h     |   4 +
 python/test/unit/operators/test_inductor.py  |  21 +++
 6 files changed, 195 insertions(+), 24 deletions(-)

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 0af8eceaade1..8ad32a30c992 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -88,6 +88,8 @@ class ScanLoweringHelper {
   unsigned getNonAxisNumThreadsPerCTA();
   // Return the number of warps per CTA along axis dim.
   unsigned getAxisNumWarps();
+  // Return the number of warps per CTA along axis dim with unique data.
+  unsigned getAxisNumWarpsWithUniqueData();
   // Return the number of threads per warp along axis dim.
   unsigned getAxisNumThreadsPerWarp();
   // Return the number of blocks along axis dim.
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 6b4141170042..ec9ffaab9ffe 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -237,11 +237,20 @@ unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
   unsigned numParallelWarpsPerCTA = product(warpsPerCTA);
   return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
 }
+
 unsigned ScanLoweringHelper::getAxisNumWarps() {
   auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
   return warpsPerCTA[getAxis()];
 }
 
+unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
+  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
+  auto shape = type.getShape();
+  auto warpsPerCTA =
+      triton::gpu::getWarpsPerCTAWithUniqueData(srcEncoding, shape);
+  return warpsPerCTA[getAxis()];
+}
+
 unsigned ScanLoweringHelper::getAxisNumBlocks() {
   auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
   auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
@@ -282,6 +291,9 @@ bool ScanLoweringHelper::isSupported() {
 }
 
 unsigned ScanLoweringHelper::getScratchSizeInBytes() {
+  unsigned axisNumWarps = getAxisNumWarpsWithUniqueData();
+  if (axisNumWarps == 1)
+    return 0;
   auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
   unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
   auto mod = scanOp->getParentOfType<ModuleOp>();
diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
index 42db21bd05a6..285118be3456 100644
--- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -7,6 +7,7 @@ using namespace mlir::triton;
 
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
+using ::mlir::LLVM::shflIdxSync;
 using ::mlir::LLVM::shflUpSync;
 using ::mlir::LLVM::storeShared;
 
@@ -41,7 +42,6 @@ static void scanThreadContiguousElements(SmallVector<Value> &srcValues,
   // contiguous in srcValues. Keep track of what elements belong to the same
   // chunk of contiguous elements.
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); - unsigned parallelElementsPerThread = helper.getAxisNumElementsPerThread(); unsigned numChunks = srcValues.size() / scanElementsPerThreads; unsigned stride = helper.getAxisElementStride(); SmallVector accs(numChunks); @@ -98,7 +98,7 @@ static void storeWarpAccumulator(SmallVector &srcValues, unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned scanDim = helper.getAxisNumThreadsPerWarp(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); - unsigned numWarps = helper.getAxisNumWarps(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); unsigned chunkId = 0; unsigned elementStride = helper.getAxisElementStride(); for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { @@ -109,7 +109,7 @@ static void storeWarpAccumulator(SmallVector &srcValues, Value lastElement = srcValues[srcIndex]; Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); - index = add(index, i32_val(chunkId * numParallelLane * numWarps)); + index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index); storeShared(rewriter, loc, writePtr, lastElement, mask); chunkId++; @@ -128,11 +128,11 @@ static void AddPartialReduce(SmallVector &srcValues, Value parallelLaneId) { Location loc = helper.getLoc(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); - unsigned numWarps = helper.getAxisNumWarps(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); @@ -164,9 +164,10 @@ static void AddPartialReduce(SmallVector &srcValues, unsigned accumulatorIndex = chunkId % parallelElementsPerThread + parallelBlockId * parallelElementsPerThread; Accumulator &accumulator = accumulators[accumulatorIndex]; - for (unsigned i = 0; i < numWarps; ++i) { - Value index = add(parallelLaneId, - i32_val(numParallelLane * (i + chunkId * numWarps))); + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = add(parallelLaneId, i32_val(numParallelLane * + (i + chunkId * axisNumWarps))); Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index); Value partialReduce = load(ptr); if (!accumulator.acc) { @@ -182,7 +183,6 @@ static void AddPartialReduce(SmallVector &srcValues, } Value temp = srcValues[srcIndex]; accumulate(rewriter, helper.getCombineOp(), temp, accumulator.maskedAcc); - unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. 
@@ -211,6 +211,75 @@ static void AddPartialReduce(SmallVector &srcValues, } } +static void AddPartialReduceOneWarp(SmallVector &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Value &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + accumulate(rewriter, helper.getCombineOp(), srcValues[srcIndex], + accumulator); + // Update the rest of the contiguous elements. + Value lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + lastElement = + shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride); + lastElement = select(maskFirstLane, accumulator, lastElement); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator = + shflIdxSync(loc, rewriter, srcValues[srcIndex], laneIdLast); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + Value laneValue = srcValues[srcIndex - i * elementStride]; + accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement); + if (axisBlockId == 0) + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue = select(maskFirstThread, + srcValues[srcIndex - i * elementStride], laneValue); + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. 
+ chunkId++; + } +} + namespace { struct ScanOpConversion : public ConvertTritonGPUOpToLLVMPattern { @@ -227,6 +296,12 @@ struct ScanOpConversion } private: + SmallVector getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const; + SmallVector getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const; std::tuple getDelinearizedIds(ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value laneId, @@ -235,6 +310,34 @@ struct ScanOpConversion ConversionPatternRewriter &rewriter) const; }; +SmallVector +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, laneId, threadsPerWarp, order); +} + +SmallVector +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, warpId, warpsPerCTA, order); +} + // Break up the threadId into lane and warp id along the scan dimension and // compute a flat id for the parallel dimensions. std::tuple @@ -290,6 +393,9 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, getDelinearizedIds(rewriter, helper, laneId, warpId); auto input = adaptor.getOperands()[0]; auto type = op.getOperand(0).getType().cast(); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + auto axisNumThreads = helper.getAxisNumThreadsPerWarp(); + warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); SmallVector srcValues = getTypeConverter()->unpackLLElements(loc, input, rewriter, type); @@ -299,18 +405,33 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // elements. warpScan(srcValues, rewriter, helper, laneIdAxis); - // Store the partial reducing for each warp into shared memory. - Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); - Value baseSharedMemPtr = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); - storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, - baseSharedMemPtr, flatIdParallel); - barrier(); - // Read back the partial reduction of each warp and accumulate them based on - // warpId. Then update each chunk of contiguous elements by adding the - // accumulated value from the previous lane. - AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis, - laneIdAxis, flatIdParallel); + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); + Value baseSharedMemPtr = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); + // Store the partial reducing for each warp into shared memory. 
+ storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + baseSharedMemPtr, flatIdParallel); + barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis, + laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); + multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + triton::gpu::getOrder(helper.getEncoding())); + AddPartialReduceOneWarp(srcValues, rewriter, helper, warpIdAxis, laneIdAxis, + laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. Value results = getTypeConverter()->packLLElements(loc, srcValues, rewriter, input.getType()); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 06d338685909..b2464c33e617 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -270,7 +270,7 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, } static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, - Value val, int i, const std::string &shuffleType, + Value val, Value i, const std::string &shuffleType, const std::string &clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); @@ -291,7 +291,7 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, auto &shfl = builder.create("shfl.sync")->o(shuffleType).o("b32"); auto *dOpr = builder.newOperand("=r"); auto *aOpr = builder.newOperand(val, "r"); - auto *bOpr = builder.newConstantOperand(i); + auto *bOpr = builder.newOperand(i, "r"); auto *cOpr = builder.newConstantOperand(clamp); auto *maskOpr = builder.newConstantOperand("0xffffffff"); shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); @@ -300,13 +300,24 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i, "bfly", "0x1f"); + return commonShflSync(loc, rewriter, val, i32_val(i), "bfly", "0x1f"); } Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i, "up", "0x0"); + return commonShflSync(loc, rewriter, val, i32_val(i), "up", "0x0"); } + +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + int i) { + return commonShflSync(loc, rewriter, val, i32_val(i), "idx", "0x1f"); +} + +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + Value i) { + return commonShflSync(loc, rewriter, val, i, "idx", "0x1f"); +} + Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { PTXBuilder builder; auto &mov = builder.create("mov")->o("u32"); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 9dd072d0c942..43faa333e05a 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h 
@@ -327,6 +327,10 @@ Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + int i); +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + Value i); Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, StringRef key, StringRef content); diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index fa157d2c9771..835f8c14281c 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -1,3 +1,4 @@ +import pytest import torch import triton @@ -153,3 +154,23 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): out_ref[:, :, 0::7, 1:7] = 2 / 3 out_ref[:, :, 0::7, 0::7] = 4 / 9 torch.testing.assert_close(out, out_ref) + + +@pytest.mark.parametrize("RBLOCK", [32, 64, 128]) +@pytest.mark.parametrize("num_warps", [1, 4]) +def test_scan2d_broadcast(RBLOCK, num_warps): + @triton.jit(debug=True) + def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + rindex = tl.arange(0, RBLOCK)[None, :] + xindex = tl.arange(0, XBLOCK)[:, None] + data = tl.load(in_ptr + rindex) + scan = tl.cumsum(data, 1) + expected_max = tl.sum(data, 1) + tl.device_assert(scan <= expected_max) + tl.store(out_ptr + xindex * RBLOCK + rindex, scan) + + XBLOCK = 4 + input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int32, device='cuda') + output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int32, device='cuda') + fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + torch.testing.assert_allclose(output, input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))) From 73dae775df3cc72761f633839e199a28937a9f42 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 18 Sep 2023 15:07:41 -0700 Subject: [PATCH 062/122] [DOCS] improved fused attention tutorial (bwd pass) (#2332) --- .gitignore | 1 - python/tutorials/06-fused-attention.py | 689 ++++++++++++++++--------- 2 files changed, 454 insertions(+), 236 deletions(-) diff --git a/.gitignore b/.gitignore index fed9cbf4ea64..e85433df82f7 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,5 @@ docs/_build/ docs/python-api/generated/ docs/dialects/ docs/getting-started/tutorials -python/tutorials/ !python/tutorials/*.py !python/tutorials/*.rst diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 996235c793be..eac1330c4077 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -3,11 +3,11 @@ =============== This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team Extra Credits: - Original flash attention paper (https://arxiv.org/abs/2205.14135) - Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) -- Adam P. 
Goucher for simplified vector math """ @@ -19,221 +19,416 @@ @triton.jit -def max_fn(x, y): - return tl.math.max(x, y) +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, qk_scale, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + else: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(tl.float16), v) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i @triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, - Out, +def _attn_fwd( + Q, K, V, sm_scale, M, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, P_SEQ, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + Z, H, + N_CTX: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, + STAGE: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) - q_offset = off_hz * stride_qh - kv_offset = off_hz * stride_kh + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, + base=Q + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), ) K_block_ptr = tl.make_block_ptr( - base=K + kv_offset, - shape=(BLOCK_DMODEL, N_CTX + P_SEQ), + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) + order=(0, 1), ) - V_block_ptr = tl.make_block_ptr( - base=V + kv_offset, - shape=(N_CTX + P_SEQ, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), ) # initialize offsets 
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) # load q: it will stay in SRAM throughout q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(tl.float16) - # loop over k, v and update accumulator - lo = 0 - hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ - for start_n in range(lo, hi, BLOCK_N): - # -- load k, v -- - k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16) - if IS_CAUSAL: - qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk += tl.dot(q, k, out_dtype=tl.float16) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - # -- scale and update acc -- - acc *= alpha[:, None] - acc += tl.dot(p.to(tl.float16), v) - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - # write back l and m + # stage 1: off-band + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, qk_scale, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 1, offs_m, offs_n, + ) + # barrier makes it easier for compielr to schedule the + # two loops independently + tl.debug_barrier() + # stage 2: on-band + if STAGE & 2: + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, qk_scale, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 2, offs_m, offs_n, + ) + # epilogue + m_i += tl.math.log2(l_i) acc = acc / l_i[:, None] - l_ptrs = L + off_hz * N_CTX + offs_m - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + q_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(tl.float16)) + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) @triton.jit -def _bwd_preprocess( - Out, DO, +def _attn_bwd_preprocess( + O, DO, Delta, + Z, H, N_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) off_n = tl.arange(0, D_HEAD) # load - o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - # compute + o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]) + do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) # write-back - tl.store(Delta + off_m, delta) + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. 
+@triton.jit +def _attn_bwd_dkdv( + dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, BLOCK_DMODEL) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv +# the main inner-loop logic for computing dQ @triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, +def _attn_bwd_dq( + dq, q, K, V, + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, BLOCK_DMODEL) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd( + Q, K, V, sm_scale, + DO, DQ, DK, DV, - L, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, P_SEQ, - num_block_q, num_block_kv, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, + M, D, + # shared by Q/K/V/DO. 
+ stride_z, stride_h, stride_tok, stride_d, + H, N_CTX, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, ): - off_hz = tl.program_id(0) - off_z = off_hz // H - off_h = off_hz % H - qk_scale = sm_scale * 1.44269504 + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_kz + off_h * stride_kh - DV += off_z * stride_vz + off_h * stride_vh - for start_n in range(0, num_block_kv): - if CAUSAL: - lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0) - else: - lo = 0 - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - D_ptrs = D + off_hz * N_CTX - l_ptrs = L + off_hz * N_CTX - # initialize dk amd dv - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # recompute p = softmax(qk, dim=-1).T - if CAUSAL: - qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) - else: - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, tl.trans(k)) - qk *= qk_scale - l_i = tl.load(l_ptrs + offs_m_curr) - p = tl.math.exp2(qk - l_i[:, None]) - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, tl.trans(v)) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) - # compute dq - dq = tl.load(dq_ptrs) - dq += tl.dot(ds.to(Q.dtype.element_ty), k) - tl.store(dq_ptrs, dq) - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - tl.store(dk_ptrs, dk) + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, BLOCK_DMODEL) + + if (tl.program_id(1) == 0): + + # THIS BLOCK DOES DK/DV/DR: + + start_n = pid 
* BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True, + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False, + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d tl.store(dv_ptrs, dv) + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + else: + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True, + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False, + ) + # Write back dQ. 
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints @@ -246,76 +441,99 @@ def forward(ctx, q, k, v, causal, sm_scale): num_stages = 4 if Lk <= 64 else 3 num_warps = 4 grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) - L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2] - _fwd_kernel[grid]( - q, k, v, sm_scale, - L, - o, + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _attn_fwd[grid]( + q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], P_SEQ, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - IS_CAUSAL=causal, + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + STAGE=3, num_warps=num_warps, - num_stages=num_stages) + num_stages=num_stages, + ) - ctx.save_for_backward(q, k, v, o, L) + ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = Lk ctx.causal = causal - ctx.P_SEQ = P_SEQ return o @staticmethod def backward(ctx, do): - BLOCK = 128 - q, k, v, o, L = ctx.saved_tensors - do = do.contiguous() - dq = torch.zeros_like(q, dtype=torch.float32) + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) - delta = torch.empty_like(L) - _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( o, do, delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + BATCH, N_HEAD, N_CTX, + BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do, - dq, dk, dv, - L, delta, + grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, + M, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], ctx.P_SEQ, - ctx.grid[0], triton.cdiv(k.shape[2], BLOCK), - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - CAUSAL=ctx.causal, - num_stages=1, + N_HEAD, N_CTX, + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, ) + return dq, dk, dv, None, None attention = _attention.apply -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)]) -@pytest.mark.parametrize('causal', [False, True]) -def 
test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): +@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 2, 1024, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ) + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: p[:, :, M == 0] = float("-inf") @@ -342,33 +560,41 @@ def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): try: from flash_attn.flash_attn_interface import \ flash_attn_qkvpacked_func as flash_attn_func - FLASH_VER = 2 + HAS_FLASH = True except BaseException: - try: - from flash_attn.flash_attn_interface import flash_attn_func - FLASH_VER = 1 - except BaseException: - FLASH_VER = None -HAS_FLASH = FLASH_VER is not None + HAS_FLASH = False BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 15)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal} -) for mode in ['fwd', 'bwd'] for causal in [False, True]] +configs = [ + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton"] + (["flash"] if HAS_FLASH else []), + line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.float16, + "mode": mode, + "causal": causal, + }, + ) + for mode in ["fwd", "bwd"] + for causal in [True] +] @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): - assert mode in ['fwd', 'bwd'] +def bench_flash_attention( + BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda" +): + assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 if provider == "triton": @@ -377,36 +603,29 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) sm_scale = 1.3 fn = lambda: 
attention(q, k, v, causal, sm_scale) - if mode == 'bwd': + if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) - if FLASH_VER == 1: - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) - cu_seqlens[1:] = lengths.cumsum(0) - qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) - fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) - elif FLASH_VER == 2: - fn = lambda: flash_attn_func(qkv, causal=causal) - else: - raise ValueError(f'unknown {FLASH_VER = }') - if mode == 'bwd': + qkv = torch.randn( + (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True + ) + fn = lambda: flash_attn_func(qkv, causal=causal) + if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 - if mode == 'bwd': + if mode == "bwd": total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) return total_flops / ms * 1e-9 # only works on post-Ampere GPUs right now -bench_flash_attention.run(save_path='.', print_data=True) +bench_flash_attention.run(save_path=".", print_data=True) From 3a848e272930bf04f5cf2db7269b1035b4742c89 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 18 Sep 2023 15:08:19 -0700 Subject: [PATCH 063/122] [BACKEND] Relax patterns to move sink broadcast and hoist convert (#2331) Improve patterns that sync broadcast to reduce the arithmetic density and also hoist convert on top of expand_dims to do less work. This address comments in https://github.com/openai/triton/pull/2274 --- .../Triton/Transforms/ReorderBroadcast.cpp | 10 ++++++---- .../Transforms/RemoveLayoutConversions.cpp | 16 +++++++++++----- test/Triton/reorder-broadcast.mlir | 12 ++++++++++++ test/TritonGPU/combine.mlir | 4 +++- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index 6c0c9fcc9341..931777bfa3a6 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -116,18 +116,20 @@ struct MoveBroadcastAfterElementwisePattern auto operands = op->getOperands(); bool seenBroadcast = false; - Type srcType; + ArrayRef srcShape; for (auto operand : operands) { auto definingOp = operand.getDefiningOp(); if (!definingOp) { return mlir::failure(); } - + auto getSrcShape = [](triton::BroadcastOp b) { + return b.getSrc().getType().cast().getShape(); + }; if (auto broadcastOp = llvm::dyn_cast(definingOp)) { if (!seenBroadcast) { seenBroadcast = true; - srcType = broadcastOp.getSrc().getType(); - } else if (srcType != broadcastOp.getSrc().getType()) { + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { // If the broadcast have different types we cannot re-order. 
return mlir::failure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 2d8ca362465a..b7f88948b982 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -929,7 +929,7 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { auto isExtOrBroadcastOp = [](Operation *op) { return isa(op); + triton::BroadcastOp, triton::ExpandDimsOp>(op); }; // 1. Take a backward slice of all the tensor dependencies. SetVector slice; @@ -950,8 +950,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { if (isExtOrBroadcastOp(op)) { SetVector tempSlice; DenseMap tempLayout; + std::optional srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; LogicalResult result = getRematerializableSlice( - op->getOperand(0), layout[v], tempSlice, tempLayout); + op->getOperand(0), *srcEncoding, tempSlice, tempLayout); // If we can rematerialize the rest of the ext slice we can ignore this // ext as it won't need a convert. if (result.succeeded()) { @@ -969,13 +972,16 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { if (extOrBroadcatOp == nullptr) return; + std::optional srcEncoding = + inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]); + if (!srcEncoding) + return; // Move the convert before the ext op and rewrite the slice. OpBuilder builder(extOrBroadcatOp); auto tensorType = extOrBroadcatOp->getOperand(0).getType().cast(); - auto newType = - RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), - layout[extOrBroadcatOp->getResult(0)]); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), *srcEncoding); auto newConvertOp = builder.create( convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); IRMapping mapping; diff --git a/test/Triton/reorder-broadcast.mlir b/test/Triton/reorder-broadcast.mlir index d5e054337a08..201b81b1e746 100644 --- a/test/Triton/reorder-broadcast.mlir +++ b/test/Triton/reorder-broadcast.mlir @@ -53,3 +53,15 @@ tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tenso tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32> } + +// CHECK-LABEL: @test_broadcast_mix_type_op_pattern +tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) { + // CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast1 = tt.splat %arg1 : (f32) -> tensor<128x128xf32> + %cond = tt.broadcast %arg3 : (tensor<128x1xi1>) -> tensor<128x128xi1> + %sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32> + + tt.return %sel : tensor<128x128xf32> +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 5c3fdd6b9ba6..8f5685ae8649 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1189,10 +1189,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: reduce_cvt2 // Match the reduction +// CHECK-NOT: triton_gpu.convert_layout // CHECK: tt.reduce // CHECK-SAME: axis = 1 -// CHECK: (tensor<1x256xf32, 
#blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> +// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>> // CHECK: triton_gpu.convert_layout +// CHECK: tt.expand_dims // CHECK-NOT: triton_gpu.convert_layout // CHECK: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> From 8e75e392ae5009c5311cb72b508ecdb2585d152e Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Tue, 19 Sep 2023 00:12:00 -0700 Subject: [PATCH 064/122] [FRONTEND] Fix Python error handling in launch (#2334) This was regressed by #2185 because we didn't realise CUDA_CHECK macro could do Python calls (similar to what led to #2225). I think the PyErr_Occurred got removed in that PR because there was missing error handling before the call to _launch, so it looked like it was just in the wrong place. It looks like there are also potentially a couple places in cuda.c that can return with error set, e.g. getDeviceProperties, memAlloc, memcpyHtoD, memFree, tensorMapEncodeTiled etc, but those are all pre-existing and not affected by recent changes. --- python/triton/compiler/make_launcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index 68dd8aeb19e0..c7dd75ec72a8 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -255,6 +255,9 @@ def format_of(ty): Py_BEGIN_ALLOW_THREADS; _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ return NULL; From ae07b7b3d3a9240f893a8213da88b75b546603ec Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 20 Sep 2023 01:00:33 +0200 Subject: [PATCH 065/122] Integration fixes for llvm/llvm-project#66754 (#2338) llvm/llvm-project#66754 extends the `LoopLikeOpInterface`: the signature of `getLoopBody` has changed. `ForOp::getRegion` can be used instead. This change works with and without llvm/llvm-project#66754. --- lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp | 6 +++--- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index fffdb05559ee..5b980caa23be 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -729,8 +729,8 @@ struct SCFForPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); - rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), - newOp.getLoopBody().end()); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); // Now, update all the types. 
@@ -739,7 +739,7 @@ struct SCFForPattern : public OpConversionPattern { // The entry block may have a special conversion if `entryConversion` is // provided. On success, the new entry block to the region is returned for // convenience. Otherwise, failure is returned. - if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *getTypeConverter()))) { return rewriter.notifyMatchFailure(op, "could not convert body types"); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 0a370100fc9a..67b86e23860f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -347,7 +347,7 @@ LogicalResult LoopPipeliner::collectOps(SetVector &ops) { void LoopPipeliner::collectValueDep(Value v, int stage, SetVector &deps) { // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getLoopBody()) + if (v.getParentRegion() != &forOp.getRegion()) return; // Since we only need to peel the loop numStages-1 times, don't worry @@ -671,7 +671,7 @@ void LoopPipeliner::createBufferTypes() { } void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getLoopBody().front()) { + for (Operation &op : *forOp.getBody()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) From 5491707093aae50005ddaa3f848e3ee16bdd4417 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 19 Sep 2023 16:14:55 -0700 Subject: [PATCH 066/122] Switch pre-commit clang-format to v16.0.6. (#2342) Google uses clang-format at LLVM HEAD. clang-format's formatting is not stable, so we want to minimize the difference between the pre-commit clang-format and HEAD to minimize differences with Google's formatter. In practice, it appears that there are no relevant changes to the formatting, so this is a nop. :shrug: Tested by running `pre-commit run --all-files`. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dfe756c27d2c..1729fc92e887 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: ^docs/conf.py$ ) - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v14.0.6 + rev: v16.0.6 hooks: - id: clang-format stages: [commit, push, manual] From 363182928cb5a000099aa935fa9556006fea98ea Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 19 Sep 2023 23:05:47 -0700 Subject: [PATCH 067/122] Add instructions for building with custom LLVM (#2344) I tested these locally, seems to work for me. --- README.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/README.md b/README.md index baec97a68e86..fbdd3027414b 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,46 @@ pip install ninja cmake; # build-time dependencies pip install -e python ``` +# Building with a custom LLVM + +Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build +downloads a prebuilt LLVM, but you can also build LLVM from source and use that. + +LLVM does not have a stable API, so the Triton build will not work at an +arbitrary LLVM version. + +1. Find the version of LLVM that Triton builds against. Check `python/setup.py` + for a line like + + version = "llvm-17.0.0-c5dede880d17" + + This means that the version of Triton you have builds against + [LLVM](https://github.com/llvm/llvm-project) c5dede880d17. + +2. 
`git checkout` LLVM at this revision. Optionally, make additional + modifications to LLVM. + +3. [Build LLVM](https://llvm.org/docs/CMake.html). For example, you might run + + $ cd $HOME/llvm-project # your clone of LLVM. + $ mkdir build + $ cd build + $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir" + $ ninja + +4. Grab a snack, this will take a while. + +5. Build Triton as above, but set the following environment variables. + + # Modify as appropriate to point to your LLVM build. + $ export LLVM_BUILD_DIR=$HOME/llvm-project/build + + $ cd /python + $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ + LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR \ + LLVM_SYSPATH=$LLVM_BUILD_DIR \ + pip install -e python + # Changelog Version 2.0 is out! New features include: From e5eda098b38e69297ee2e0b83a86ad8005e44b5b Mon Sep 17 00:00:00 2001 From: Dongdong Li Date: Wed, 20 Sep 2023 14:23:46 +0800 Subject: [PATCH 068/122] [TESTS] fix flash attention (#2086) Co-authored-by: dongdongl --- .github/workflows/integration-tests.yml | 2 + .../Transforms/FenceInsertion.cpp | 78 ++++++- .../test/unit/hopper/test_flashattention.py | 17 +- test/TritonGPU/fence-inserstion.mlir | 205 ++++++++++++++++++ 4 files changed, 282 insertions(+), 20 deletions(-) create mode 100644 test/TritonGPU/fence-inserstion.mlir diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index c5b1cf2101c7..2424c3034177 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -105,6 +105,8 @@ jobs: python3 -m pytest runtime/ # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py + #run hopper/test_flashattention.py to avoid out of gpu memory + python3 -m pytest hopper/test_flashattention.py - name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index eb5c0f2ffcd2..3d839874607d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -3,6 +3,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// @@ -32,16 +33,20 @@ struct FenceInsertionPass FenceInsertionPass(int computeCapability) { this->computeCapability = computeCapability; } - // TODO: support more patterns to insert fences - // only support insertion between convert layout ops and dot ops to protect - // flashattention + // TODO: support more general patterns to insert fences. eg. any op(generic) + // to shared in use-def chain which refers by async proxy. 
We have generic( + // convertlayout with sts/stmatix) + fence + async(wgmma) up to now void runOnOperation() override { // Only insert fences for compute capability 9.0 if (computeCapability < 90) return; + // ENABLE_MMA_V3 + if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (isa(op)) { + if (isa(op)) { + OpBuilder builder(op); auto a = op->getOperand(0); auto b = op->getOperand(1); auto mmaEncoding = op->getResult(0) @@ -50,21 +55,70 @@ struct FenceInsertionPass .getEncoding() .dyn_cast(); auto isHopperEncoding = mmaEncoding && mmaEncoding.isHopper(); - if (isHopperEncoding && (a.getDefiningOp() && - ttg::isSharedEncoding(a)) || - (b.getDefiningOp() && - ttg::isSharedEncoding(b))) { - - // TODO: check whether cluster fence is needed - OpBuilder builder(op); + if (isHopperEncoding && + (dependOnSharedEncOperand(a) || dependOnSharedEncOperand(b))) { builder.create(op->getLoc(), false /*bCluster*/); } } }); } -}; +private: + bool dependOnSharedEncOperand(Value operand) { + static DenseSet> trace; + auto op = operand.getDefiningOp(); + // avoid redundant insertion + if (op && isa(op)) + return false; + // reach convertlayout + if (op && isa(op) && ttg::isSharedEncoding(operand)) + return true; + // root and not BlockArgument + if (!op && !isa(operand)) + return false; + // op and not BlockArgument + if (op && !isa(operand)) { + for (auto v : op->getOperands()) { + if (dependOnSharedEncOperand(v)) + return true; + } + } + // reach BlockArgument + // TODO: support other scf ops, IfOp, WhileOp, etc. + if (BlockArgument arg = dyn_cast(operand)) { + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + // suport ForOp only + if (auto forOp = dyn_cast(argOwner)) { + // prologue + auto iterOperands = forOp.getIterOperands(); + if (argNum == 0) + return false; + if (dependOnSharedEncOperand(iterOperands[argNum - 1])) + return true; + // yield + auto yieldOp = forOp.getBody()->getTerminator(); + Value v = yieldOp->getOperand(argNum - 1); + auto entry = std::make_pair(std::move(yieldOp), + std::move(argNum)); + // avoid cyclic + if (trace.contains(entry)) + return false; + else + trace.insert(entry); + + if (dependOnSharedEncOperand(v)) + return true; + } else if (auto whileOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported WhileOp"); + } else if (auto ifOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported IfOp"); + } + } + return false; + } +}; } // namespace std::unique_ptr diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py index e46e1c1f2c59..60006613b625 100644 --- a/python/test/unit/hopper/test_flashattention.py +++ b/python/test/unit/hopper/test_flashattention.py @@ -368,14 +368,15 @@ def backward(ctx, do): attention = _attention.apply -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64), - # (4, 48, 256, 64), - # (4, 48, 512, 64), - # (4, 48, 1024, 64), - # (4, 48, 2048, 64), - # (4, 48, 4096, 64), - # (4, 48, 8192, 64), out of memory - ]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 128, 64), + (4, 48, 256, 64), + (4, 48, 512, 64), + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + # (4, 48, 8192, 64), out of memory +]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) diff --git 
a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir new file mode 100644 index 000000000000..c5ef88dbc68b --- /dev/null +++ b/test/TritonGPU/fence-inserstion.mlir @@ -0,0 +1,205 @@ +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_like_fence_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c128_i32 = arith.constant 128 : i32 + %c1_i32 = arith.constant 1 : i32 + %c3_i32 = arith.constant 3 : i32 + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %4 = arith.extsi %arg4 : i32 to i64 + %5 = arith.extsi %arg7 : i32 to i64 + %6 = tt.make_tensor_ptr %arg1, [%1, %4], [%5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %7 = arith.extsi %arg8 : i32 to i64 + %8 = tt.make_tensor_ptr %arg2, [%0, %4], [%7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %9 = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : tensor<3xi64, #shared> + %10 = arith.cmpi sgt, %arg5, %c0_i32 : i32 + %11 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %12 = tt.splat %10 : (i1) -> tensor<128x128xi1, #blocked1> + %13 = triton_nvidia_gpu.extract_mbarrier %9[%c0_i32] : tensor<3xi64, #shared>, i32 -> + %14 = triton_nvidia_gpu.get_thread_id : i32 + %15 = arith.cmpi eq, %14, %c0_i32 : i32 + %16 = arith.andi %15, %10 : i1 + triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 
{axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %20 = tt.advance %3, [%c0_i32, %c128_i32] : , 1> + %21 = tt.advance %6, [%c128_i32, %c0_i32] : , 1> + %22 = arith.cmpi sgt, %arg5, %c128_i32 : i32 + %23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1> + %24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> + %25 = arith.andi %15, %22 : i1 + triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1) : i32 { + %33 = triton_nvidia_gpu.extract_mbarrier %9[%arg21] : tensor<3xi64, #shared>, i32 -> + triton_nvidia_gpu.mbarrier_wait %33, %arg23 : + // CHECK: triton_nvidia_gpu.fence_async_shared + %34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma> + triton_nvidia_gpu.dot_wait {pendings = 1 : i32} + %35 = tt.advance %arg11, [%c0_i32, %c128_i32] : , 1> + %36 = tt.advance %arg12, [%c128_i32, %c0_i32] : , 1> + %37 = arith.addi %arg19, %c128_i32 : i32 + %38 = arith.cmpi slt, %37, %arg5 : i32 + %39 = arith.addi %arg21, %c1_i32 : i32 + %40 = arith.cmpi uge, %39, %c3_i32 : i32 + %41 = arith.select %40, %c0_i32, %39 : i32 + %42 = tt.advance %arg17, [%c0_i32, %c128_i32] : , 1> + %43 = tt.advance %arg18, [%c128_i32, %c0_i32] : , 1> + %44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1> + %45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> + %46 = arith.andi %15, %38 : i1 + 
triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %b_48 = triton_gpu.convert_layout %48 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1> + %s_48 = triton_gpu.convert_layout %b_48 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1> + %51 = arith.addi %arg20, %c1_i32 : i32 + %52 = arith.cmpi uge, %51, %c3_i32 : i32 + %53 = arith.select %52, %c0_i32, %51 : i32 + %54 = arith.addi %arg22, %c1_i32 : i32 + %55 = arith.xori %arg23, %true : i1 + %56 = arith.cmpi ult, %39, %c3_i32 : i32 + %57 = arith.andi %40, %55 : i1 + %58 = arith.andi %56, %arg23 : i1 + %59 = arith.ori %57, %58 : i1 + %60 = arith.xori %arg24, %true : i1 + %61 = arith.cmpi ult, %51, %c3_i32 : i32 + %62 = arith.andi %52, %60 : i1 + %63 = arith.andi %61, %arg24 : i1 + %64 = arith.ori %62, %63 : i1 + scf.yield %34, %35, %36, %47, %49, %s_48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1 + } + scf.if %10 { + triton_nvidia_gpu.dot_wait {pendings = 0 : i32} + } + %31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1> + triton_nvidia_gpu.store_async %8, %32 : !tt.ptr, 1>, tensor<128x128xf16, #shared1> + triton_gpu.async_bulk_commit_group + triton_gpu.async_bulk_wait {num = 0 : i32} + tt.return + } +} + + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_like_fence_2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c128_i32 = arith.constant 128 : i32 + %c1_i32 = arith.constant 1 : i32 + %c3_i32 = arith.constant 3 : i32 + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %4 = arith.extsi %arg4 : i32 to i64 + %5 = arith.extsi %arg7 : i32 to i64 + %6 = tt.make_tensor_ptr %arg1, [%1, %4], [%5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %7 = arith.extsi %arg8 : i32 to i64 + %8 = tt.make_tensor_ptr %arg2, [%0, %4], [%7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %9 = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : tensor<3xi64, #shared> + %10 = arith.cmpi sgt, %arg5, %c0_i32 : i32 + %11 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %12 = tt.splat %10 : (i1) -> tensor<128x128xi1, #blocked1> + %13 = triton_nvidia_gpu.extract_mbarrier %9[%c0_i32] : tensor<3xi64, #shared>, i32 -> + %14 = triton_nvidia_gpu.get_thread_id : i32 + %15 = arith.cmpi eq, %14, %c0_i32 : i32 + %16 = arith.andi %15, %10 : i1 + triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %20 = tt.advance %3, [%c0_i32, %c128_i32] : , 1> + %21 = tt.advance %6, [%c128_i32, %c0_i32] : , 1> + %22 = arith.cmpi sgt, %arg5, %c128_i32 : i32 + %23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1> + %24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> + %25 = arith.andi %15, %22 : i1 + triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, 
#shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %b_29 = triton_gpu.convert_layout %29 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1> + %s_29 = triton_gpu.convert_layout %b_29 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1> + %30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %s_29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1) : i32 { + %33 = triton_nvidia_gpu.extract_mbarrier %9[%arg21] : tensor<3xi64, #shared>, i32 -> + triton_nvidia_gpu.mbarrier_wait %33, %arg23 : + // CHECK: triton_nvidia_gpu.fence_async_shared + %34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma> + triton_nvidia_gpu.dot_wait {pendings = 1 : i32} + %35 = tt.advance %arg11, [%c0_i32, %c128_i32] : , 1> + %36 = tt.advance %arg12, [%c128_i32, %c0_i32] : , 1> + %37 = arith.addi %arg19, %c128_i32 : i32 + %38 = arith.cmpi slt, %37, %arg5 : i32 + %39 = arith.addi %arg21, %c1_i32 : i32 + %40 = arith.cmpi uge, %39, %c3_i32 : i32 + %41 = arith.select %40, %c0_i32, %39 : i32 + %42 = tt.advance %arg17, [%c0_i32, %c128_i32] : , 1> + %43 = tt.advance %arg18, [%c128_i32, %c0_i32] : , 1> + %44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1> + %45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> + %46 = arith.andi %15, %38 : i1 + triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %51 = arith.addi %arg20, %c1_i32 : i32 
+ %52 = arith.cmpi uge, %51, %c3_i32 : i32 + %53 = arith.select %52, %c0_i32, %51 : i32 + %54 = arith.addi %arg22, %c1_i32 : i32 + %55 = arith.xori %arg23, %true : i1 + %56 = arith.cmpi ult, %39, %c3_i32 : i32 + %57 = arith.andi %40, %55 : i1 + %58 = arith.andi %56, %arg23 : i1 + %59 = arith.ori %57, %58 : i1 + %60 = arith.xori %arg24, %true : i1 + %61 = arith.cmpi ult, %51, %c3_i32 : i32 + %62 = arith.andi %52, %60 : i1 + %63 = arith.andi %61, %arg24 : i1 + %64 = arith.ori %62, %63 : i1 + scf.yield %34, %35, %36, %47, %49, %48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1 + } + scf.if %10 { + triton_nvidia_gpu.dot_wait {pendings = 0 : i32} + } + %31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1> + triton_nvidia_gpu.store_async %8, %32 : !tt.ptr, 1>, tensor<128x128xf16, #shared1> + triton_gpu.async_bulk_commit_group + triton_gpu.async_bulk_wait {num = 0 : i32} + tt.return + } +} From ed5a53057d39358dc857dcefd73415a6da9a2e25 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 20 Sep 2023 12:25:52 -0400 Subject: [PATCH 069/122] [BACKEND] Handle repetitive threads in scan op when the tensor dim is small (#2345) https://github.com/openai/triton/issues/2298 --- include/triton/Analysis/Utility.h | 3 + lib/Analysis/Utility.cpp | 65 ++++++++++--------- .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 10 +-- python/test/unit/operators/test_inductor.py | 2 +- 4 files changed, 43 insertions(+), 37 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 8ad32a30c992..59af824097f8 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -92,6 +92,8 @@ class ScanLoweringHelper { unsigned getAxisNumWarpsWithUniqueData(); // Return the number of threads per warp along axis dim. unsigned getAxisNumThreadsPerWarp(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); // Return the number of blocks along axis dim. unsigned getAxisNumBlocks(); // Return the number of blocks along non axis dim. 
@@ -109,6 +111,7 @@ class ScanLoweringHelper { Location getLoc() { return scanOp.getLoc(); } unsigned getAxis() { return scanOp.getAxis(); } triton::gpu::BlockedEncodingAttr getEncoding(); + llvm::ArrayRef getShape(); Region &getCombineOp(); private: diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ec9ffaab9ffe..1d2a6b2e9bc1 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -212,7 +212,8 @@ unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { } unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { - SmallVector sizePerThreads = getContigPerThread(getEncoding()); + SmallVector sizePerThreads = + triton::gpu::getContigPerThread(getEncoding()); sizePerThreads[getAxis()] = 1; return product(sizePerThreads); } @@ -223,6 +224,11 @@ unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() { return triton::gpu::getThreadsPerWarp(getEncoding())[getAxis()]; } +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return triton::gpu::getThreadsPerWarpWithUniqueData(getEncoding(), + getShape())[getAxis()]; +} + unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); threadsPerWarp[getAxis()] = 1; @@ -239,42 +245,36 @@ unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { } unsigned ScanLoweringHelper::getAxisNumWarps() { - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); - return warpsPerCTA[getAxis()]; + return triton::gpu::getWarpsPerCTA(getEncoding())[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { - auto type = scanOp.getOperand(0).getType().cast(); - auto shape = type.getShape(); - auto warpsPerCTA = - triton::gpu::getWarpsPerCTAWithUniqueData(srcEncoding, shape); - return warpsPerCTA[getAxis()]; + return triton::gpu::getWarpsPerCTAWithUniqueData(getEncoding(), + getShape())[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumBlocks() { - auto type = scanOp.getOperand(0).getType().cast(); - auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); unsigned axis = getAxis(); return ceil( - type.getShape()[axis], + getShape()[axis], (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); } unsigned ScanLoweringHelper::getNonAxisNumBlocks() { - auto type = scanOp.getOperand(0).getType().cast(); - auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); unsigned axis = getAxis(); unsigned numBlocks = 1; for (unsigned i = 0; i < sizePerThreads.size(); i++) { if (i == axis) continue; - numBlocks *= ceil( - type.getShape()[i], - (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i])); + numBlocks *= + ceil(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] * + warpsPerCTA[i])); } return numBlocks; } @@ -283,7 +283,7 @@ bool ScanLoweringHelper::isSupported() { // TODO: Support the following cases: // 1. 
Scan on non-blocking encodings // 2. Scan with multiple operands - if (!isa(srcEncoding)) + if (!isa(getEncoding())) return false; if (scanOp.getNumOperands() != 1) return false; @@ -309,8 +309,12 @@ triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() { return srcEncoding.cast(); } +llvm::ArrayRef ScanLoweringHelper::getShape() { + return scanOp.getOperand(0).getType().cast().getShape(); +} + unsigned ScanLoweringHelper::getAxisElementStride() { - auto order = triton::gpu::getOrder(srcEncoding); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) @@ -321,7 +325,7 @@ unsigned ScanLoweringHelper::getAxisElementStride() { } unsigned ScanLoweringHelper::getAxisThreadStride() { - auto order = triton::gpu::getOrder(srcEncoding); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) @@ -332,18 +336,17 @@ unsigned ScanLoweringHelper::getAxisThreadStride() { } unsigned ScanLoweringHelper::getAxisBlockStride() { - auto order = triton::gpu::getOrder(srcEncoding); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; - auto type = scanOp.getOperand(0).getType().cast(); - auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); for (unsigned dim : order) { if (dim == getAxis()) return stride; - stride *= ceil(type.getShape()[dim], sizePerThreads[dim] * - threadsPerWarp[dim] * - warpsPerCTA[dim]); + stride *= ceil(getShape()[dim], sizePerThreads[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); } llvm_unreachable("Axis not found in order"); } diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 285118be3456..22aefd912d61 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -64,7 +64,7 @@ static void warpScan(SmallVector &srcValues, unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); - unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; // Only consider the last element of each contiguous chunk of elements. 
@@ -96,7 +96,7 @@ static void storeWarpAccumulator(SmallVector &srcValues, Value parallelLaneId) { Location loc = helper.getLoc(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); - unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); unsigned chunkId = 0; @@ -222,7 +222,7 @@ static void AddPartialReduceOneWarp(SmallVector &srcValues, unsigned threadStride = helper.getAxisThreadStride(); unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); - unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); @@ -394,7 +394,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, auto input = adaptor.getOperands()[0]; auto type = op.getOperand(0).getType().cast(); auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - auto axisNumThreads = helper.getAxisNumThreadsPerWarp(); + auto axisNumThreads = helper.getAxisNumThreadsPerWarpWithUniqueData(); warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); SmallVector srcValues = getTypeConverter()->unpackLLElements(loc, input, rewriter, type); @@ -423,7 +423,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, } else if (srcValues.size() > 1) { // Fast path for the case where there is only one warp with unique data on // the axis. - unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index 835f8c14281c..17ba2eb9c8fb 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -156,7 +156,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): torch.testing.assert_close(out, out_ref) -@pytest.mark.parametrize("RBLOCK", [32, 64, 128]) +@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) @pytest.mark.parametrize("num_warps", [1, 4]) def test_scan2d_broadcast(RBLOCK, num_warps): @triton.jit(debug=True) From 9cab885dffb1e5c83ed155249e3864a4671a36a5 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 20 Sep 2023 14:05:12 -0700 Subject: [PATCH 070/122] [BACKEND] Optimize wgmma with accumulator source equal to 0 (#2343) Also add a test for MMA v3 reduction. 
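For reference, the case this targets is a dot whose accumulator operand is an all-zero splat constant, e.g. what a plain `tl.dot(a, b)` with no carried accumulator produces. The sketch below is illustrative only (it is not part of this patch, and the kernel/tensor names are made up); it assumes a CUDA build of Triton:

    # Illustrative sketch: `tl.dot(a, b)` with no carried accumulator lowers to a
    # dot whose accumulator input is a zero splat, so on Hopper the first wgmma
    # in the chain no longer needs a real accumulator source.
    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def single_tile_dot(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        # Load one BLOCK x BLOCK tile of each operand (row-major).
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
        c = tl.dot(a, b)  # accumulator starts at zero
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


    a = torch.randn((64, 64), device="cuda", dtype=torch.float16)
    b = torch.randn((64, 64), device="cuda", dtype=torch.float16)
    c = torch.empty((64, 64), device="cuda", dtype=torch.float32)
    single_tile_dot[(1,)](a, b, c, BLOCK=64)
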
--- .../TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 18 ++++++++++++++++-- python/test/unit/language/test_core.py | 3 ++- test/Conversion/tritongpu_to_llvm_hopper.mlir | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index e69e89b5030d..fb5cfd9d66a9 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -278,6 +278,18 @@ static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, return newStruct; } +static bool isZero(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, @@ -302,7 +314,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, int M = 4 * instrShape[0]; int N = instrShape[1]; int K = instrShape[2]; - + bool zeroAcc = isZero(c); auto shapePerCTATile = getShapePerCTATile(mmaEncoding); int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); @@ -344,7 +356,9 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, elemTypes.push_back(accEl.getType()); auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); - Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy); + Value d; + if (!zeroAcc) + d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy); uint32_t numLowPrecisionAcc = 0; Value partialAcc; for (int k = 0; k < numRepK; ++k) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index fe105e6ba398..e67bc9b4fb75 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1831,7 +1831,8 @@ def test_scan_layouts(M, N, src_layout, axis, device): BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 16, 16]), ] diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 2095788623f4..15020d058366 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -149,3 +149,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.return } } + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> 
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_zero_acc + // Generate a wgmma with 2 sources. + // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { + tt.func @dot_zero_acc(%a: tensor<128x64xf16, #shared>, %b: tensor<64x64xf16, #shared1>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %m = triton_nvidia_gpu.dot_async %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + tensor<128x64xf16, #shared> * tensor<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> + tt.return + } +} From bcaf14755a0496f9e0cc9574d9e4a8a768d568aa Mon Sep 17 00:00:00 2001 From: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Date: Thu, 21 Sep 2023 05:06:56 +0800 Subject: [PATCH 071/122] [HOPPER] enable flash attention with tma (#2336) --- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 35 +++++++----------- .../unit/operators/test_flash_attention.py | 16 +++++--- python/triton/ops/flash_attention.py | 37 ++++++++++--------- 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index c050c3ad3a38..6b01b3bac546 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -21,6 +21,19 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; +static CUtensorMapDataType getCUtensorMapDataType(Type ty) { + if (ty.isF16()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (ty.isBF16()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (ty.isF32()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else { + llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } +} + // Contains some helper functions for both Load and Store conversions. 
struct LoadStoreConversionBase { explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass) @@ -857,17 +870,6 @@ struct StoreAsyncOpConversion } private: - CUtensorMapDataType getCUtensorMapDataType(Type ty) const { - if (ty.isF16()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if (ty.isF32()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else { - llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp"); - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } - } - unsigned getArgIdx(Value v) const { if (auto op = v.getDefiningOp()) { return -1 - @@ -1728,17 +1730,6 @@ struct InsertSliceAsyncV2OpConversion return bcastMask; } - CUtensorMapDataType getCUtensorMapDataType(Type ty) const { - if (ty.isF16()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if (ty.isF32()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else { - llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } - } - unsigned getArgIdx(Value v) const { if (auto op = v.getDefiningOp()) { return -1 - diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 75da98e5044a..179410faea0b 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -13,17 +13,23 @@ @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): - # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() - if enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]: - pytest.skip('Segmentation fault') + if enable_tma in ["on", "true", "1"]: + if dtype == torch.bfloat16: + pytest.skip('bfloat16 tma not support currently') + if '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD])) in [ + "True-True-2-4-512-16", + "True-True-2-4-512-32", + "True-False-2-4-512-16", + "True-False-2-4-512-32", + ]: + pytest.skip('backward ref check failed') capability = torch.cuda.get_device_capability() interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"] if not interpreter and capability[0] < 8: - pytest.skip("Flash attention only supported for compute capability < 80") + pytest.skip("Flash attention only supported for compute capability >= 80") torch.manual_seed(20) q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 1ae37c297f81..187ec21377af 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -24,6 +24,7 @@ def _fwd_kernel( stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, + Z_H_N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, @@ -31,27 +32,21 @@ def _fwd_kernel( start_m = tl.program_id(0) off_hz = tl.program_id(1) qvk_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - 
strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) + vk_offset = qvk_offset // stride_qm + K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), + base=K, + shape=(BLOCK_DMODEL, Z_H_N_CTX), strides=(stride_kk, stride_kn), - offsets=(0, 0), + offsets=(0, vk_offset), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), strides=(stride_vn, stride_vk), - offsets=(0, 0), + offsets=(vk_offset, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0) ) @@ -68,7 +63,11 @@ def _fwd_kernel( # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) + + offs_k = tl.arange(0, BLOCK_DMODEL) + Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + q = tl.load(Q_ptrs) + q = (q * qk_scale).to(K.dtype.element_ty) lo = 0 hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX @@ -100,13 +99,14 @@ def _fwd_kernel( tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + base=Out, + shape=(Z_H_N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), + offsets=(vk_offset + start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) + # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) @@ -312,6 +312,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], q.shape[2], + q.shape[0] * q.shape[1] * q.shape[2], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, IS_CAUSAL=causal, num_warps=num_warps, From 8094f466326d877e0bf202c51430b82d915e23ae Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Thu, 21 Sep 2023 04:31:20 +0100 Subject: [PATCH 072/122] [FRONTEND][BACKEND] Fix various atomic_rmw bugs (#2355) This fixes a few bugs I've encountered - `atomic_add` with int64/uint64 `Operation .add requires .u32 or .s32 or .u64 [...] 
for instruction 'atom'` - `atomic_min/max` with float64 -> `ValueError('Cannot bitcast data-type of size 64 to data-type of size 32')` - `atomic_min/max` with float32 returns the old value as int32 --- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 2 +- python/test/unit/language/test_core.py | 4 +++ python/triton/language/semantic.py | 34 +++++++++++++------ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 6b01b3bac546..44bc9e40b734 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1103,7 +1103,7 @@ struct AtomicRMWOpConversion sTy = "b" + sBits; break; case RMWOp::ADD: - sTy = "s" + sBits; + sTy = "u" + sBits; break; case RMWOp::FADD: rmwOp = "add"; diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e67bc9b4fb75..63a6e2310a6c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1119,8 +1119,11 @@ def kernel(X, Y, Z): [ ('add', 'float16', mode, sem), ('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), ('add', 'int64', mode, sem), ('add', 'float64', mode, sem), ('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), ('max', 'int64', mode, sem), ('max', 'float64', mode, sem), ('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), ('min', 'int64', mode, sem), ('min', 'float64', mode, sem), ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) @@ -1139,6 +1142,7 @@ def kernel(X, Z): pid = tl.program_id(0) x = tl.load(X + pid) old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) sem_arg = sem if sem is None else f'"{sem}"' kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 7b9109b23b87..6b4de892ed30 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1139,13 +1139,20 @@ def atomic_max(ptr: tl.tensor, # for float # return atomic_smax(i_ptr, i_val) if val >= 0 # return atomic_umin(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) - neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + itype = tl.int32 if sca_ty == tl.float32 else tl.float64 + zero = full([], 0.0, sca_ty, builder) + + i_val = bitcast(val, itype, builder) + i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type) neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type) - return where(pos, pos_ret, neg_ret, builder) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) 
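Taken together, the fixes above and below make kernels like the following work end to end; this is a minimal repro-style sketch of the user-visible behavior (CUDA device assumed, names and sizes arbitrary), not code from the patch:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _atomics_kernel(X, SUM64, MAX64):
    pid = tl.program_id(0)
    x = tl.load(X + pid)                    # float64 element
    old = tl.atomic_max(MAX64, x)           # float64 min/max used to fail in the 64->32 bitcast
    tl.static_assert(old.dtype == x.dtype)  # the returned old value now keeps the input dtype
    tl.atomic_add(SUM64, x.to(tl.int64))    # int64 add used to emit atom.add.s64, which ptxas rejects


x = torch.randn(128, device="cuda", dtype=torch.float64)
sum64 = torch.zeros(1, device="cuda", dtype=torch.int64)
max64 = torch.full((1,), float("-inf"), device="cuda", dtype=torch.float64)
_atomics_kernel[(x.numel(),)](x, sum64, max64)
assert max64.item() == x.max().item()
assert sum64.item() == x.to(torch.int64).sum().item()
```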
def atomic_min(ptr: tl.tensor, @@ -1175,10 +1182,16 @@ def atomic_min(ptr: tl.tensor, # for float # return atomic_smin(i_ptr, i_val) if val >= 0 # return atomic_umax(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) - neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + itype = tl.int32 if sca_ty == tl.float32 else tl.float64 + zero = full([], 0.0, sca_ty, builder) + + i_val = bitcast(val, itype, builder) + i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, @@ -1191,7 +1204,8 @@ def atomic_min(ptr: tl.tensor, and_(mask, neg, builder).handle, sem), i_val.type) - return where(pos, pos_ret, neg_ret, builder) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) def atomic_add(ptr: tl.tensor, From be9849bda9eda6b4aba2eee31d1b9dc9639933bf Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Thu, 21 Sep 2023 15:42:29 +0200 Subject: [PATCH 073/122] [BACKEND] Set min bitwidth of shared store&load (#2358) Using `i1` results in st.shared.b1, which does not exist. Set a min bit width here to handle this case. Resolves issue https://github.com/openai/triton/issues/2351 Co-authored-by: Keren Zhou --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index b2464c33e617..6d8fc8509f47 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -240,7 +240,7 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred) { MLIRContext *ctx = rewriter.getContext(); - unsigned bits = val.getType().getIntOrFloatBitWidth(); + unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); PTXBuilder builder; @@ -257,7 +257,7 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, auto ptrTy = ptr.getType().cast(); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r"); From e36c99b58876306a9eebfff328d5c2d93b4f7e78 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 21 Sep 2023 12:00:41 -0700 Subject: [PATCH 074/122] [BACKEND] Handle scan of function non commutative (#2362) Make sure we accumulate in the right order for scans so that non commutative operations are handled correctly. 
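The user-visible case this protects is `tl.associative_scan` with a combine function that is associative but not commutative. A minimal sketch mirroring the `get_first_element` test added further down (CUDA device assumed):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _first_element(a, b):
    # associative but not commutative: the result depends on operand order
    return a


@triton.jit
def _scan_kernel(X, Z, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    z = tl.associative_scan(x, axis=0, combine_fn=_first_element)
    tl.store(Z + offs, z)


x = torch.randn(1024, device="cuda")
z = torch.empty_like(x)
_scan_kernel[(1,)](x, z, BLOCK=1024)
assert torch.all(z == x[0])   # every prefix's "first element" is x[0]
```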
--- .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 45 ++++++++++--------- python/test/unit/language/test_core.py | 27 ++++++++--- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 22aefd912d61..f91223265a44 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -11,13 +11,15 @@ using ::mlir::LLVM::shflIdxSync; using ::mlir::LLVM::shflUpSync; using ::mlir::LLVM::storeShared; -// Apply the region of the scan op to the acc and cur values and update acc -// inplace with the result. -static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - Value &acc, Value cur) { - if (!acc) { - acc = cur; - return; +// apply combine region to a and b and return the result. If a or b is null, +// return the other operand. +static Value accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + Value a, Value b) { + if (!a) { + return b; + } + if (!b) { + return a; } // Create a new copy of the reduce block, and inline it Block *currentBlock = rewriter.getBlock(); @@ -25,13 +27,14 @@ static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, rewriter.cloneRegionBefore(combineOp, &parent.front()); auto &newScan = parent.front(); auto returnOp = dyn_cast(newScan.getTerminator()); - llvm::SmallVector combineArgs = {acc, cur}; + llvm::SmallVector combineArgs = {a, b}; rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), combineArgs); auto results = returnOp.getResult(); - acc = results[0]; + Value acc = results[0]; // Delete the terminator, which is no longer used rewriter.eraseOp(returnOp); + return acc; } // Scan a contiguous elements within a thread and update `srcValues` in place. @@ -49,8 +52,8 @@ static void scanThreadContiguousElements(SmallVector &srcValues, unsigned accIndex = (srcIndex % stride) + ((srcIndex / stride) / scanElementsPerThreads) * stride; - accumulate(rewriter, helper.getCombineOp(), accs[accIndex], - srcValues[srcIndex]); + accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], + srcValues[srcIndex]); srcValues[srcIndex] = accs[accIndex]; } } @@ -75,7 +78,7 @@ static void warpScan(SmallVector &srcValues, for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) { Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride); Value tempAcc = acc; - accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl); + tempAcc = accumulate(rewriter, helper.getCombineOp(), shfl, tempAcc); Value mask = icmp_slt(laneIdAxis, i32_val(i)); acc = select(mask, acc, tempAcc); } @@ -175,14 +178,14 @@ static void AddPartialReduce(SmallVector &srcValues, accumulator.maskedAcc = partialReduce; continue; } - accumulate(rewriter, helper.getCombineOp(), accumulator.acc, - partialReduce); + accumulator.acc = accumulate(rewriter, helper.getCombineOp(), + accumulator.acc, partialReduce); Value mask = icmp_slt(warpId, i32_val(i + 1)); accumulator.maskedAcc = select(mask, accumulator.maskedAcc, accumulator.acc); } - Value temp = srcValues[srcIndex]; - accumulate(rewriter, helper.getCombineOp(), temp, accumulator.maskedAcc); + Value temp = accumulate(rewriter, helper.getCombineOp(), + accumulator.maskedAcc, srcValues[srcIndex]); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. 
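The invariant the changes above restore, shown on the host for clarity: with an associative but non-commutative combine function, the value accumulated so far must always be passed as the left (earlier) operand. A tiny plain-Python illustration, not tied to the patch:

```python
def combine(a, b):          # associative, not commutative
    return a

xs = [3, 1, 4, 1, 5]

acc, correct = None, []
for v in xs:                # accumulated prefix on the left
    acc = v if acc is None else combine(acc, v)
    correct.append(acc)     # -> [3, 3, 3, 3, 3], the expected "first element" scan

acc, swapped = None, []
for v in xs:                # operands swapped: silently computes a different scan
    acc = v if acc is None else combine(v, acc)
    swapped.append(acc)     # -> [3, 1, 4, 1, 5]

assert correct == [3, 3, 3, 3, 3] and swapped == xs
```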
@@ -195,7 +198,8 @@ static void AddPartialReduce(SmallVector &srcValues, lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement); for (unsigned i = 1; i < scanElementsPerThreads; ++i) { Value laneValue = srcValues[srcIndex - i * elementStride]; - accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement); + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. @@ -251,8 +255,8 @@ static void AddPartialReduceOneWarp(SmallVector &srcValues, if (axisBlockId == 0) // First chunk and first block accumulator = srcValues[srcIndex]; else - accumulate(rewriter, helper.getCombineOp(), srcValues[srcIndex], - accumulator); + srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), + accumulator, srcValues[srcIndex]); // Update the rest of the contiguous elements. Value lastElement = srcValues[srcIndex]; if (scanDim > 1) { @@ -266,7 +270,8 @@ static void AddPartialReduceOneWarp(SmallVector &srcValues, } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { Value laneValue = srcValues[srcIndex - i * elementStride]; - accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement); + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); if (axisBlockId == 0) // For the first warp and first chunk we don't have anything to // accumulate. diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 63a6e2310a6c..ca2cb208c643 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1716,10 +1716,16 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp for type in ['int32', 'float32'] for axis in [1, 0] for shape in scan2d_shapes - for op in ['cumsum', 'cumprod'] + for op in ['cumsum', 'cumprod', 'get_first_element'] ] +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + @pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs) def test_scan2d(op, dtype_str, shape, axis, num_warps, device): if is_hip(): @@ -1735,15 +1741,26 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp z = GENERATE_TEST_HERE tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'}) + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'}) + else: + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.associative_scan(x, axis={axis}, combine_fn={op})'}) # input rs = RandomState(17) x = numpy_random(shape, dtype_str=dtype_str, rs=rs) z = np.empty_like(x) x_tri = to_triton(x, device=device) - numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] - z_dtype_str = dtype_str - z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_dtype_str = dtype_str + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + z_ref[1:] = x[0] + else: + z_ref[:, 1:] = x[:, 0:1] # triton result z_tri = to_triton(z, device=device) kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) From c4bc3fd92f79a1a470f0a3a086c0fa525124bc6d Mon Sep 17 00:00:00 
2001 From: Philippe Tillet Date: Thu, 21 Sep 2023 13:46:30 -0700 Subject: [PATCH 075/122] [BACKEND] Fix-up memory leak (#2365) --- lib/Target/PTX/PTXTranslation.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index fe8841997c35..3ae1bac1a15a 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -87,9 +88,9 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine( + std::unique_ptr machine{target->createTargetMachine( module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, llvm::CodeGenOpt::Aggressive); + std::nullopt, llvm::CodeGenOpt::Aggressive)}; // set data layout if (layout.empty()) module.setDataLayout(machine->createDataLayout()); From 32c9d2bb8fa4e75b2cf1c60cc5e3392c113276f2 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 21 Sep 2023 15:05:57 -0700 Subject: [PATCH 076/122] [FRONTEND] improved error messages (#2363) this is a combination of #1774 and #2006, which I cannot edit but fail CI pre-commit hook --- python/test/unit/language/test_core.py | 2 +- python/triton/language/semantic.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ca2cb208c643..3c5707c27248 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -746,7 +746,7 @@ def test_invalid_pid_axis(device): def _kernel(dst): pid = tl.program_id(20) - with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"): + with pytest.raises(triton.CompilationError, match=r"program_id axis must be 0, 1, or 2 but got 20"): _kernel[(1,)](dst) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 6b4de892ed30..c9b6fef79cfa 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -25,10 +25,14 @@ def __init__(self, type_a, type_b): # ===----------------------------------------------------------------------===## def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") return tl.tensor(builder.create_get_program_id(axis), tl.int32) def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") return tl.tensor(builder.create_get_num_programs(axis), tl.int32) # ===----------------------------------------------------------------------===// @@ -128,6 +132,8 @@ def add(input: tl.tensor, input, other = binary_op_type_checking_impl(input, other, builder, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise ValueError("cannot add pointers together") # offset + ptr # ptr + offset From d543eb1a364471814993d7f12f9ef05c47adb603 Mon Sep 17 00:00:00 2001 From: Alexander Zinoviev <8257131+alexander-zinoviev@users.noreply.github.com> Date: Thu, 21 Sep 2023 16:40:53 -0700 Subject: [PATCH 077/122] [BACKEND] implement `dot` for INT8 on Turing (#2364) Replace a single mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32 
instruction that is used on Ampere with 4 x mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 instructions for Turing Extracted the Turing-int8, Turing-fp16 and Ampere to separate functions. Somehow I messed up with my previous PR, so just open a new one. --------- Co-authored-by: Philippe Tillet --- .../TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp | 147 ++++++++++++++---- python/test/unit/language/test_core.py | 17 +- 2 files changed, 124 insertions(+), 40 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp index bd222f9e6a72..870c8980b490 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -145,6 +145,9 @@ inline static const std::map mmaInstrPtxTuring = { {TensorCoreType::FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"}, + {TensorCoreType::INT32_INT8_INT8_INT32, + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32"}, + {TensorCoreType::FP16_FP16_FP16_FP16, "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16"}, }; @@ -168,6 +171,107 @@ inline static const std::map mmaInstrPtxAmpere = { "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"}, }; +static void callMmaTuringInt8(PTXBuilder &builder, unsigned m, unsigned n, + unsigned k, mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc) { + auto retArgs1 = builder.newListOperand(numMmaRets / 2, "=r"); + auto retArgs2 = builder.newListOperand(numMmaRets / 2, "=r"); + auto cArgs1 = builder.newListOperand(); + for (int i = 0; i < numMmaRets / 2; ++i) { + cArgs1->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto cArgs2 = builder.newListOperand(); + for (int i = numMmaRets / 2; i < numMmaRets; ++i) { + cArgs2->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs1 = builder.newListOperand({ + {ha[{m, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({ + {hb[{n, k}], "r"}, + }); + auto aArgs2 = builder.newListOperand({ + {ha[{m, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + auto aArgs3 = builder.newListOperand({ + {ha[{m + 1, k}], "r"}, + }); + auto bArgs3 = builder.newListOperand({ + {hb[{n, k}], "r"}, + }); + auto aArgs4 = builder.newListOperand({ + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs4 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + mma(retArgs1, aArgs1, bArgs1, cArgs1); + mma(retArgs1, aArgs2, bArgs2, cArgs1); + mma(retArgs2, aArgs3, bArgs3, cArgs2); + mma(retArgs2, aArgs4, bArgs4, cArgs2); +} + +static void callMmaTuringFp16(PTXBuilder &builder, unsigned m, unsigned n, + unsigned k, mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc, + bool isAccF16) { + auto retArgs = builder.newListOperand(numMmaRets, isAccF16 ? 
"=r" : "=f"); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < numMmaRets; ++i) { + cArgs->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs1 = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({{hb[{n, k}], "r"}}); + auto aArgs2 = builder.newListOperand({ + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + mma(retArgs, aArgs1, bArgs1, cArgs); + mma(retArgs, aArgs2, bArgs2, cArgs); +} + +static void callMmaAmpere(PTXBuilder &builder, unsigned m, unsigned n, + unsigned k, mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc, + bool isAccF16, bool isIntMMA) { + auto retArgs = + builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f"); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < numMmaRets; ++i) { + cArgs->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs = + builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); + mma(retArgs, aArgs, bArgs, cArgs); +} + LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Value a, Value b, Value c, Value d, Value loadedA, @@ -215,42 +319,19 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, // using =r for float32 works but leads to less readable ptx. bool isIntMMA = dTensorTy.getElementType().isInteger(32); bool isAccF16 = dTensorTy.getElementType().isF16(); - auto retArgs = - builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? 
"=r" : "=f"); - auto cArgs = builder.newListOperand(); - for (int i = 0; i < numMmaRets; ++i) { - cArgs->listAppend(builder.newOperand( - fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], - std::to_string(i))); - // reuse the output registers - } if (isTuring) { - auto aArgs1 = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - }); - auto bArgs1 = builder.newListOperand({ - {hb[{n, k}], "r"}, - }); - auto aArgs2 = builder.newListOperand({ - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); - mma(retArgs, aArgs1, bArgs1, cArgs); - mma(retArgs, aArgs2, bArgs2, cArgs); - } else { - auto aArgs = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs = - builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); - mma(retArgs, aArgs, bArgs, cArgs); + if (isIntMMA) // Turing int8 + callMmaTuringInt8(builder, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc); + else // Turing fp16 + callMmaTuringFp16(builder, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc, isAccF16); + } else { // Ampere + callMmaAmpere(builder, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc, isAccF16, isIntMMA); } + Value mmaOut = builder.launch(rewriter, loc, getMmaRetType(mmaType, op.getContext())); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3c5707c27248..8a4d1a2607b4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2256,7 +2256,7 @@ def kernel(X, stride_xm, stride_xn, [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] - for allow_tf32 in [True] + for allow_tf32 in [True, False] for col_a in [True, False] for col_b in [True, False] for in_dtype, out_dtype in [('int8', 'int8'), @@ -2282,12 +2282,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") if capability[0] < 8: - if in_dtype == 'int8': - pytest.skip("Only test int8 on devices with sm >= 80") - elif allow_tf32: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if allow_tf32: pytest.skip("Only test tf32 on devices with sm >= 80") if capability[0] == 7: - if (M, N, K, num_warps) == (128, 256, 32, 8): + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: pytest.skip("shared memory out of resource") if out_dtype == 'float16': # TODO: support out_dtype=float16 for tl.dot on V100 @@ -2472,8 +2472,11 @@ def kernel(X, stride_xm, stride_xk, else: assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) elif in_dtype == 'int8': - assert 'wgmma.mma_async.sync.aligned' in ptx or\ - 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx @pytest.mark.parametrize('in_dtype', ['float32']) From c71ec14f31966de6de79742ac0a3700e0e694bb9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 21 Sep 2023 21:23:19 -0700 Subject: [PATCH 078/122] [TEST] only test 4 configs without TF32 (#2370) --- 
.github/workflows/integration-tests.yml | 3 ++- python/test/unit/language/test_core.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 2424c3034177..c0b7e987aa56 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -100,7 +100,8 @@ jobs: if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} run: | cd python/test/unit - python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py + python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py + python3 -m pytest -n 8 language/test_subprocess.py # run runtime tests serially to avoid race condition with cache handling. python3 -m pytest runtime/ # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8a4d1a2607b4..53f026b3c42f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2256,13 +2256,16 @@ def kernel(X, stride_xm, stride_xn, [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] - for allow_tf32 in [True, False] + for allow_tf32 in [True] for col_a in [True, False] for col_b in [True, False] for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), - ('float32', 'float32')]]) + ('float32', 'float32')]] + + + [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') + for col_a in [True, False] for col_b in [True, False]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) From 293b7fd592a1602f2305c1bd0bc978bbd97337d6 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Thu, 21 Sep 2023 22:37:14 -0700 Subject: [PATCH 079/122] [TESTING] cleanup (#2293) Co-authored-by: Philippe Tillet --- python/test/regression/test_performance.py | 18 +++------------ python/triton/ops/matmul_perf_model.py | 9 +++++--- python/triton/testing.py | 26 ++++------------------ 3 files changed, 13 insertions(+), 40 deletions(-) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index b22fea3e53a9..dcab2e168ade 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -1,13 +1,10 @@ -import subprocess -import sys - import pytest import torch import triton import triton.language as tl import triton.ops -from triton.testing import get_dram_gbps, get_max_tensorcore_tflops +from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, nvsmi DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]] @@ -21,15 +18,6 @@ def print_perf(cur_ms, cur_util, ref_util): print(f'{cur_ms:.3f} ms \t cur: {cur_util:.3f} \t ref: {ref_util:.3f} \t dif={cur_util - ref_util:.3f}', end='\t') -def nvsmi(attrs): - attrs = ','.join(attrs) - cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] - out = subprocess.check_output(cmd) - ret = out.decode(sys.stdout.encoding).split(',') - ret = [int(x) for x in ret] - return ret - - ####################### # Matrix Multiplication ####################### @@ -51,9 +39,9 @@ def nvsmi(attrs): (16, 8192, 8192): {'float16': 
0.077, 'float32': 0.077, 'int8': 0.043}, (64, 1024, 1024): {'float16': 0.018, 'float32': 0.023, 'int8': 0.017}, (64, 4096, 4096): {'float16': 0.150, 'float32': 0.000, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.338, 'float32': 0.000, 'int8': 0.174}, + (64, 8192, 8192): {'float16': 0.214, 'float32': 0.000, 'int8': 0.174}, (1024, 64, 1024): {'float16': 0.029, 'float32': 0.046, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.179, 'float32': 0.214, 'int8': 0.102}, + (4096, 64, 4096): {'float16': 0.136, 'float32': 0.214, 'int8': 0.102}, (8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177}, # test EVEN_K==False (8192, 8192, 8176): {'float16': 0.786, 'float32': 0.743, 'int8': 0.51}, diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index cdb66bc9cad2..abe5325ee056 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -5,14 +5,16 @@ from .. import cdiv from .._C.libtriton.triton import runtime from ..runtime import driver -from ..testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, + nvsmi) def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): ''' return compute throughput in TOPS ''' total_warps = num_ctas * min(num_warps, 4) num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device) + cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device) return tflops @@ -20,7 +22,8 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): ''' return compute throughput in TOPS ''' total_warps = num_ctas * min(num_warps, 4) num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device) + cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, cur_sm_clock, backend, device) return tflops diff --git a/python/triton/testing.py b/python/triton/testing.py index 69ee467d6b07..da7664adda86 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -368,7 +368,7 @@ def get_dram_gbps(backend=None, device=None): return bw_gbps -def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None): +def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None): import torch from .runtime import driver @@ -378,8 +378,6 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None) device = torch.cuda.current_device() num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 - if not clock_rate: - clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz capability = torch.cuda.get_device_capability(device) if capability[0] < 8: assert dtype == torch.float16 @@ -423,21 +421,6 @@ def wrapper(*args, **kwargs): return decorator -def nvsmi_attr(attrs): - attrs = ",".join(attrs) - cmd = [ - "nvidia-smi", - "-i", - "0", - "--query-gpu=" + attrs, - "--format=csv,noheader,nounits", - ] - out = subprocess.check_output(cmd) - ret = out.decode(sys.stdout.encoding).split(",") - ret = [int(x) for x in 
ret] - return ret - - @contextmanager def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): try: @@ -458,8 +441,8 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", ] ) - cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0] - cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0] + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock @@ -471,7 +454,7 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) -def get_max_simd_tflops(dtype, backend=None, device=None): +def get_max_simd_tflops(dtype, clock_rate, backend=None, device=None): import torch from .runtime import driver @@ -481,7 +464,6 @@ def get_max_simd_tflops(dtype, backend=None, device=None): device = torch.cuda.current_device() num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 - clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz capability = torch.cuda.get_device_capability() if capability[0] < 8: if dtype == torch.float32: From 413b18eb73515f86fa57313c847c85a61a5c6465 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 22 Sep 2023 23:34:20 +0800 Subject: [PATCH 080/122] [FROJTEND] fix core.dtype.__repr__ (#2372) `function_type` does not have a `name` field, which leads to an error when debugging with gdb. --- python/triton/language/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 150d3936018f..496aa42000c9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -275,7 +275,7 @@ def cache_key_part(self) -> str: return self.name def __repr__(self): - return f'triton.language.{self.name}' + return f'triton.language.{str(self)}' class pointer_type(dtype): From 1724604bd9d97b4912e11d5e813e2372b9afc2e1 Mon Sep 17 00:00:00 2001 From: Bin Fan Date: Fri, 22 Sep 2023 11:16:35 -0700 Subject: [PATCH 081/122] [DOCS] Add a tutorial example of grouped gemm (#2326) --- python/tutorials/11-grouped-gemm.py | 297 ++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 python/tutorials/11-grouped-gemm.py diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py new file mode 100644 index 000000000000..034e4e217abe --- /dev/null +++ b/python/tutorials/11-grouped-gemm.py @@ -0,0 +1,297 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import torch + +import triton +import triton.language as tl + +# This group gemm kernel launches a fixed number of CTA to compute a group +# of gemms. The scheduling is static and we do it on device + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + } + ), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + } + ), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + } + ), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + } + ), + ], + key=['group_size'], +) +@triton.jit +def grouped_matmul_kernel( + # device tensor of matrices pointers + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + # device tensor of gemm sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + group_gemm_sizes, + # device tensor of leading dimension sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + g_lds, + # number of gemms + group_size, + # number of virtual SM + NUM_SM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size of the current problem + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + # iterate through the tiles in the current gemm problem + while ( + tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles + ): + # pick up a tile from the current gemm problem + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + # figure out tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # do regular gemm here + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 + ) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + # hint to Triton compiler to do proper loop pipelining + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + # assume full tile for now + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, 
b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + # assumes full tile for now + tl.store(c_ptrs, c) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # get ready to go to the next gemm problem + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_fn(group_A, group_B): + device = torch.device('cuda') + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=device, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs .append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + + # note these are device tensors + d_a_ptrs = torch.tensor(A_addrs, device=device) + d_b_ptrs = torch.tensor(B_addrs, device=device) + d_c_ptrs = torch.tensor(C_addrs, device=device) + d_g_sizes = torch.tensor( + g_sizes, dtype=torch.int32, device=device + ) + d_g_lds = torch.tensor( + g_lds, dtype=torch.int32, device=device + ) + # we use a fixed number of CTA, and it's auto-tunable + grid = lambda META: (META['NUM_SM'],) + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + ) + + return group_C + + +group_m = [1024, 512, 256, 128] +group_n = [1024, 512, 256, 128] +group_k = [1024, 512, 256, 128] +group_A = [] +group_B = [] +assert len(group_m) == len(group_n) +assert len(group_n) == len(group_k) +group_size = len(group_m) +for i in range(group_size): + M = group_m[i] + N = group_n[i] + K = group_k[i] + A = torch.rand((M, K), device="cuda", dtype=torch.float16) + B = torch.rand((K, N), device="cuda", dtype=torch.float16) + group_A.append(A) + group_B.append(B) + +tri_out = group_gemm_fn(group_A, group_B) +ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] +for i in range(group_size): + assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0) + + +# only launch the kernel, no tensor preparation here to remove all overhead +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + grid = lambda META: (META['NUM_SM'],) + grouped_matmul_kernel[grid]( + a_ptrs, + b_ptrs, + c_ptrs, + sizes, + lds, + group_size, + ) + + +def torch_perf_fn(group_A, group_B): + for a, b in zip(group_A, group_B): + torch.matmul(a, b) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['N'], + x_vals=[2 ** i for i in range(7, 11)], # different possible values for `x_name` + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=['cublas', 'triton'], + # label name for the lines + line_names=["cuBLAS", "Triton"], + # line styles + styles=[('green', '-'), ('blue', '-')], + ylabel="runtime(ms)", # label name for the y-axis + plot_name="group-gemm-performance", + # name for the plot. Used also as a file name for saving the plot. 
+ args={}, + ) +) +def benchmark(N, provider): + group_size = 4 + group_A = [] + group_B = [] + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = torch.rand((N, N), device="cuda", dtype=torch.float16) + B = torch.rand((N, N), device="cuda", dtype=torch.float16) + C = torch.empty((N, N), device="cuda", dtype=torch.float16) + group_A.append(A) + group_B.append(B) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs .append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device="cuda") + d_b_ptrs = torch.tensor(B_addrs, device="cuda") + d_c_ptrs = torch.tensor(C_addrs, device="cuda") + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda") + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + if provider == 'cublas': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) + return ms, max_ms, min_ms + + +benchmark.run(show_plots=True, print_data=True) From 840e7e7b530c9c14913db1c413f938093fb44b56 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 22 Sep 2023 15:21:56 -0700 Subject: [PATCH 082/122] [BACKEND] Improve decision of MMA dimension on H100 (#2373) When there is a chain of mma ops we want to pick the same shape to avoid conversions. This improves the detection going through for loops. This fixes a crash in tutorial bw attention. We might want to change this logic and convert the format to allow more efficient MMA at some point. --- .../Dialect/TritonGPU/Transforms/Utility.h | 10 +++++ .../TritonGPU/Transforms/AccelerateMatmul.cpp | 27 ++++++------ lib/Dialect/TritonGPU/Transforms/Utility.cpp | 39 ++++++++++++++++ test/TritonGPU/accelerate-matmul.mlir | 44 +++++++++++++++++++ 4 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 test/TritonGPU/accelerate-matmul.mlir diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index fe9f9f8c5953..375fdfac2356 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -141,6 +141,16 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, ArrayRef shape); +// Implement backward and forward slice that will go through scf blocks when +// yield or scf results are in the slice. +// Note that like exisiting forward and backard slice this may add operations to +// the slice that are not actually dependent on the root because when a region +// is added to the slice in the forward slice all the operations of the region +// are added. We could implement a more accurate slice method by tracking value +// usage across scf regions. 
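In words, the new helpers keep a worklist and, whenever an `scf.for` shows up in a slice, keep slicing from the loop body's yield (backward direction) or from the loop results (forward direction). A simplified Python sketch of the backward-slice worklist idea (abstract callbacks, not the actual MLIR API):

```python
# Simplified sketch of getBackwardSliceSCFAware; the callbacks stand in for MLIR calls.
def backward_slice_scf_aware(root, get_backward_slice, is_scf_for, body_terminator):
    slice_ops, queue = set(), [root]
    while queue:
        op = queue.pop()
        # ordinary backward slice, filtered so already-collected ops are skipped
        new_ops = get_backward_slice(op, filter=lambda o: o not in slice_ops)
        for dep in new_ops:
            if is_scf_for(dep):
                # results of an scf.for really come from its yield, so the loop
                # body's terminator becomes a new slicing root
                queue.append(body_terminator(dep))
        slice_ops.update(new_ops)
    return slice_ops
```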
+void getBackwardSliceSCFAware(Operation *, SetVector *slices); +void getForwardSliceSCFAware(Value root, SetVector *slices); + } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 8c82098416a8..3473c2123519 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -102,8 +102,7 @@ warpsPerTileV3(tt::DotOp dotOp, const ArrayRef shape, int numWarps, class BlockedToMMA : public mlir::RewritePattern { int computeCapability; mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding - mutable llvm::SmallVector> dotOpSetVector; - mutable llvm::SmallVector mmaV3InstrNs; + mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { return op->getNumOperands() == 1 && @@ -150,36 +149,36 @@ class BlockedToMMA : public mlir::RewritePattern { auto type = dotOp.getResult().getType().cast(); if (type.getEncoding().isa()) return currN; - for (size_t i = 0; i < dotOpSetVector.size(); ++i) { - if (dotOpSetVector[i].count(dotOp.getOperation()) > 0) - return mmaV3InstrNs[i]; - } + auto it = dotOpInstNs.find(dotOp.getOperation()); + if (it != dotOpInstNs.end()) + return it->second; SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); - mlir::getBackwardSlice(dotOp.getOperation(), &slices); + mlir::getForwardSliceSCFAware(dotOp.getResult(), &slices); + mlir::getBackwardSliceSCFAware(dotOp.getOperation(), &slices); unsigned N = currN; - llvm::SetVector dotOpSet; + SmallVector dotOps; for (Operation *iter : slices) { if (auto nextDotOp = dyn_cast(iter)) { auto type = nextDotOp.getResult().getType().cast(); auto AType = nextDotOp.getOperand(0).getType().cast(); auto shapePerCTA = ttg::getShapePerCTA(type); auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType); - dotOpSet.insert(iter); + dotOps.push_back(iter); if (instrShape[1] < N) N = instrShape[1]; } } - mmaV3InstrNs.push_back(N); - dotOpSetVector.push_back(dotOpSet); + for (Operation *dotOp : dotOps) + dotOpInstNs[dotOp] = N; return N; } static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, int opIdx) { - auto cvtOp = dyn_cast_or_null(v.getDefiningOp()); - auto arg = cvtOp.getSrc(); + Value arg = v; + if (auto cvtOp = v.getDefiningOp()) + arg = cvtOp.getSrc(); auto argType = arg.getType().cast(); auto eltType = argType.getElementType(); assert(argType.getEncoding() && "unexpected tensor type"); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index b4a5bbe920de..c455f20dc1ad 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -492,6 +492,45 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, return linear; } +void getBackwardSliceSCFAware(Operation *op, SetVector *slices) { + SmallVector queue = {op}; + while (!queue.empty()) { + Operation *currentOp = queue.back(); + queue.pop_back(); + SetVector temp; + auto filter = [slices](Operation *sliceOp) { + return slices->count(sliceOp) == 0; + }; + mlir::getBackwardSlice(currentOp, &temp, filter); + for (Operation *sliceOp : temp) { + if (auto forOp = dyn_cast(sliceOp)) { + queue.push_back(forOp.getBody()->getTerminator()); + } + } + slices->insert(temp.begin(), temp.end()); + } +} + +void getForwardSliceSCFAware(Value root, SetVector *slices) { + SmallVector queue = {root}; + while 
(!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + SetVector temp; + auto filter = [slices](Operation *sliceOp) { + return slices->count(sliceOp) == 0; + }; + mlir::getForwardSlice(currentValue, &temp, filter); + for (Operation *sliceOp : temp) { + if (auto yieldOp = dyn_cast(sliceOp)) { + auto forOp = yieldOp->getParentOfType(); + queue.append(forOp->getResults().begin(), forOp->getResults().end()); + } + } + slices->insert(temp.begin(), temp.end()); + } +} + namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir new file mode 100644 index 000000000000..fbb83a1aec39 --- /dev/null +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -0,0 +1,44 @@ +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s + +// CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK: mma_chain_loop + tt.func public @mma_chain_loop( + %170: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %171: tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %179: tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, + %164: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>, + %165: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>, + %173: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, + %153: tensor<128x64x!tt.ptr, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #blocked> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> + // CHECK: scf.for + // CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]> + %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { + %172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> + %178 = triton_gpu.convert_layout %172 : (tensor<128x16xf16, #blocked>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + %180 = tt.dot %178, %179, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, 
#triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + scf.yield %180 : tensor<128x64xf16, #blocked1> + } + // CHECK: scf.for + // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]> + %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { + %166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> + %172 = triton_gpu.convert_layout %166 : (tensor<128x32xf16, #blocked2>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + %174 = tt.dot %172, %173, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + scf.yield %174 : tensor<128x64xf16, #blocked1> + } + tt.store %153, %149 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked1> + tt.return + } +} From 215b2e77a1d92f907bc4addfa3c3ba204028c32e Mon Sep 17 00:00:00 2001 From: ian Bearman Date: Fri, 22 Sep 2023 15:29:31 -0700 Subject: [PATCH 083/122] Add Shared Middle Layer to Triton via Plug-In (#2374) This PR leverages the plug-in system to add a shared middle-layer to Triton. Currently the middle layer is not complete but has enough functionality to demonstrate how it can work. The general idea is that Triton IR is lowered into an MLIR core dialect to allow it to be both shared across Triton targets as well as allow back-ends to be shared with other languages. The basic intended architecture looks like this: [Triton IR] -> [Middle Layer] -> [HW specific IR] The middle-layer uses MLIR's Linalg and Tenor Dialects for operations on Triton block values. Operations on Triton pointers use the Memref Dialect. ## Usage To include the shared middle-layer in your Triton build do `export TRITON_CODEGEN_TRITON_SHARED=1` before invoking your build. Once it is part of the build it can be leveraged in two ways: ### Stand-Alone The middle layer can be used as a stand-alone component to convert Triton dialect to the middle layer dialects. Stand-alone example: ``` triton-shared-opt --triton-to-linalg %file ``` ### Backend Component The middle layer can also be used as a component in a Triton back-end by adding the cmake targets it produces and its headers files to that back-end. An example back-end will be published at a later date. ## Implementation details Even though a valid triton program can perform load and store in arbitrary memory locations, the prototype only supports lowering programs that have structured memory access patterns. ### Analyses As part of the conversion process, there are three important analyses: 1. Pointer analysis: + This analysis is responsible for extracting structured memory access patterns from a `triton` program during load and store; it walks the IR and visits relevant instructions to build strided memory accesses in the `memref` dialect. The analysis is still in its early stage and does not support all scenarios. 2. 
Use analysis: + After "Pointer analysis", instructions that are part of memory address calculation will no longer be necessary in a triton program because their semantics have now been captured by `memref` operations representing strided memory accesses. To aid with removing these instructions safely, we perform `Use analysis` to mark which instructions are used *only* in address calculation (called `MetaUse`) or used in *both* address calculation and data manipulation (called `MixedUse`) operations. Those that are `MixedUse` are cloned and have their users adjusted accordingly with the goal of separating out the `MetaUse` ops so that they can be safely deleted. 3. Mask analysis: + This analysis is responsible for handling masked loads and stores. ### Conversion strategy We introduce the `TritonToLinalg` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the `Pointer analysis`'s description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like: ```mlir tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr) { %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> %1 = tt.splat %afloat : (!tt.ptr) -> tensor<128x!tt.ptr> %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> %afm = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16> %3 = "tt.reduce"(%afm) ({ ^bb0(%arg5: bf16, %arg6: bf16): %21 = arith.addf %arg5, %arg6 : bf16 tt.reduce.return %21 : bf16 }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16 tt.store %res, %3 : bf16 tt.return } ``` after conversion: ```mlir func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32) { %cst = arith.constant 0.000000e+00 : f32 %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>> %alloc = memref.alloc() : memref<128xbf16> memref.copy %reinterpret_cast, %alloc : memref<128xbf16, strided<[1]>> to memref<128xbf16> %0 = bufferization.to_tensor %alloc restrict writable : memref<128xbf16> %1 = bufferization.alloc_tensor() : tensor %inserted = tensor.insert %cst into %1[] : tensor %reduced = linalg.reduce ins(%0 : tensor<128xbf16>) outs(%inserted : tensor) dimensions = [0] (%in: bf16, %init: f32) { %3 = arith.extf %in : bf16 to f32 %4 = arith.addf %3, %init : f32 linalg.yield %4 : f32 } %extracted = tensor.extract %reduced[] : tensor %2 = arith.truncf %extracted : f32 to bf16 %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1]>> affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1]>> return } ``` Important details to note: + `tt.load` (together with all of its related address calculation instructions such as `tt.addptr` and `tt.splat`) are lowered to a combination of `memref.reinterpret_cast`, `memref.alloc`, and `memref.copy`. After the initialization of the local buffer, we convert the memref back to a tensor using `bufferization.to_tensor`; this op is automatically removed during bufferization. + `tt.store` lowers to a combination of `memref.reinterpret_cast` and either `affine.store` or `memref.tensor_store`: ``` %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] 
memref<*xf32> to memref<1024xf32> %extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor %subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref memref.tensor_store %extracted_slice, %subview : memref ``` + element-wise `arith` and `math` operators are converted to their corresponding `linalg.generic` version. + `tt.dot` becomes `linalg.matmul`. + `tt.reduce` becomes `linalg.reduce`; known limitation: only support `addf` and `maxf` reduction in the reduction body for now. ### Testing The prototype was tested on the following triton kernel examples: 1. vector addition 2. fused softmax 3. matrix multiplication 4. layer normalization 5. fused attention In addition to testing on the tutorial kernels, I have also added many lit tests covering various scenarios. ## Recognition The work here represents contributions from myself as well as many of my colleagues at Microsoft. I especially want to call out @nhat-nguyen and @haishanzzz who were major contributors to this work. --- .github/workflows/integration-tests.yml | 44 +++++++++++++++++++++++++ .gitmodules | 3 ++ third_party/triton_shared | 1 + 3 files changed, 48 insertions(+) create mode 160000 third_party/triton_shared diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index c0b7e987aa56..3234a1094cd0 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -173,6 +173,50 @@ jobs: python3 -m pytest -vs . --reruns 10 sudo nvidia-smi -i 0 -rgc + Integration-Tests-Shared-Middle-Layer: + + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Clear cache + run: | + rm -rf ~/.triton + + - name: Update PATH + run: | + echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files --verbose + + - name: Install Triton + run: | + export TRITON_CODEGEN_TRITON_SHARED=1 + git submodule update --init --recursive + cd python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + python3 -m pip install ninja + python3 -m pip uninstall -y triton + python3 setup.py build + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Run shared middle-layer lit tests + run: | + python3 -m pip install lit + cd python + LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test" + if [ ! 
-d "${LIT_TEST_DIR}" ]; then + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 + fi + lit -v "${LIT_TEST_DIR}" + + Integration-Tests-Third-Party: needs: Runner-Preparation if: false diff --git a/.gitmodules b/.gitmodules index 30ba4342537e..3a989c6cc969 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,6 @@ path = third_party/amd_hip_backend url = https://github.com/ROCmSoftwarePlatform/triton branch = third_party_backend_2 +[submodule "third_party/triton_shared"] + path = third_party/triton_shared + url = https://github.com/microsoft/triton-shared diff --git a/third_party/triton_shared b/third_party/triton_shared new file mode 160000 index 000000000000..d0ac5898ff97 --- /dev/null +++ b/third_party/triton_shared @@ -0,0 +1 @@ +Subproject commit d0ac5898ff97ab33c2839306ec10bfa4fab816f5 From cb83b42ed6397d170ab539c9c0a99afff3971476 Mon Sep 17 00:00:00 2001 From: edimetia3d Date: Sat, 23 Sep 2023 08:01:54 +0800 Subject: [PATCH 084/122] [FRONTEND] using closure to create jit launcher (#2289) Hi, I'm adding some features to `triton.runtime.jit.JITFunction._make_launcher` and found it hard to debug: 1. The inlined Python code is hard to inspect in my editor. 2. My debugger fails to step into this inlined code. In response, I've introduced some code to solve these issues. My modifications include: ~~1. Refactoring the launcher's inline Python code, ensuring it only relies on the "self" object.~~ ~~2. Adding a utility method that generates a temporary file to create a launcher when debugging a kernel in the main module.~~ Using a closure to hold the launcher's body. Because this feature might be useful to others, I have opened this pull request. ~~Tests are yet to be added; if this submission is accepted, I will add them later.~~ Since this change is a refactor, no new test was added. 
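To make the shape of the change concrete, here is a minimal, self-contained sketch of the pattern (illustrative only, not Triton's actual code; `Kernel`, `make_launcher`, and `launcher_body` are made-up names): the launch logic lives in an ordinary closure, and `exec` only produces a thin wrapper that reproduces the kernel's signature and forwards to it.

```python
# Hypothetical, simplified illustration of the refactor -- not Triton's real API.
import inspect


class Kernel:
    def __init__(self, fn):
        self.fn = fn
        self.arg_names = list(inspect.signature(fn).parameters)

    def make_launcher(self):
        # The launcher's body is a plain closure: a debugger can step into it
        # and an editor can inspect it like any other function.
        def launcher_body(args_map, grid):
            key = tuple(type(args_map[name]) for name in self.arg_names)
            print(f"launching {self.fn.__name__} with key={key}, grid={grid}")
            return self.fn(**args_map)

        # Only a thin wrapper that reproduces the kernel's signature is still
        # generated with exec; all of the real work happens in launcher_body.
        args = ", ".join(self.arg_names)
        args_map = ", ".join(f'"{a}": {a}' for a in self.arg_names)
        src = (
            f"def {self.fn.__name__}({args}, grid=None):\n"
            f"    return launcher_body({{{args_map}}}, grid)\n"
        )
        scope = {"launcher_body": launcher_body}
        exec(src, scope)
        return scope[self.fn.__name__]


def add(x, y):
    return x + y


launcher = Kernel(add).make_launcher()
print(launcher(1, 2, grid=(1,)))  # prints the launch info, then 3
```

Because `launcher_body` is regular source rather than exec-generated text, breakpoints and editor navigation work on it directly, which is the point of the refactor.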
--- python/triton/runtime/jit.py | 235 +++++++++++++++++------------------ 1 file changed, 112 insertions(+), 123 deletions(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 1809ce36cf50..2a5525d1ecde 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -297,29 +297,29 @@ def __init__(self, module, name): return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={ "key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) - def _get_arg_specialization_key(self, arg) -> str: - arg_annotation = self.__annotations__.get(arg, '') + def _get_arg_specialization_key(self, arg_name, arg): + arg_annotation = self.__annotations__.get(arg_name, '') if arg_annotation == '': - return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \ - else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \ - else (False,)' + return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \ + else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \ + else (False,) elif 'Tensor' in arg_annotation: - return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)' + return (arg.data_ptr() % JITFunction.divisibility == 0) elif 'int' in arg_annotation or 'bool' in arg_annotation: - return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)' + return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) else: - return '(False,)' + return (False,) - def _get_arg_sig_key(self, arg) -> str: - arg_annotation = self.__annotations__.get(arg, '') + def _get_arg_sig_key(self, arg_name, arg) -> str: + arg_annotation = self.__annotations__.get(arg_name, '') if 'Tensor' in arg_annotation: - return f'{arg}.dtype' + return arg.dtype elif arg_annotation == 'bool': return "i1" elif arg_annotation == 'float': return 'fp32' else: - return f'_key_of({arg})' + return self._key_of(arg) def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str: device_types = [device_type for device_type in device_types if device_type != ''] @@ -337,124 +337,113 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else 'cuda' def _make_launcher(self): - regular_args = [f'{arg}' for i, arg in enumerate( + regular_args = [arg for i, arg in enumerate( self.arg_names) if i not in self.constexprs] - constexpr_args = [ - f'{arg}' for i, arg in enumerate( - self.arg_names) if i in self.constexprs] - args = ', '.join(regular_args) - # cache key for regular argument type - sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args]) - device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']' - pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']' - # cache key for constexpr argument values - constexpr_keys = ', '.join(constexpr_args) - # cache key for argument specialization - specializations = [] - for i, arg in enumerate(regular_args): - if i in self.do_not_specialize: - continue - specializations += [self._get_arg_specialization_key(arg)] - - spec_keys = ', '.join(specializations) - grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) + constexpr_args = [arg for i, arg in enumerate( + 
self.arg_names) if i in self.constexprs] + + def regular_args_v(args_proxy): + return [args_proxy[arg_name] for arg_name in regular_args] + + def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type): + from ..compiler import (CompiledKernel, compile, + get_arch_default_num_stages, + get_arch_default_num_warps) + sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args]) + constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args]) + specializations = [] + for i, arg_name in enumerate(regular_args): + if i in self.do_not_specialize: + continue + specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])] + + spec_key = tuple(specializations) + assert num_ctas > 0 + assert grid is not None + if callable(grid): + grid = grid(args_proxy) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + if device_type is None: + device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)] + device_types = [_device_type for _device_type in device_types if _device_type != ''] + device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in + regular_args_v(args_proxy)]) + + device_backend = None + if device_type not in ['cuda']: + device_backend = get_backend(device_type) + if device_backend is None: + raise ValueError('Cannot find backend for ' + device_type) + + if device is None: + if device_type in ['cuda']: + device = get_current_device() + set_current_device(device) + else: + device = device_backend.get_current_device() + device_backend.set_current_device(device) + if stream is None and not warmup: + if device_type in ['cuda']: + stream = get_cuda_stream(device) + else: + stream = device_backend.get_stream() + + if num_warps is None: + num_warps = get_arch_default_num_warps(device_type) + if num_stages is None: + num_stages = get_arch_default_num_stages(device_type) + + key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug) + if extern_libs is not None: + key = (key, tuple(extern_libs.items())) + + bin = self.cache[device].get(key, None) + if bin is not None: + # build dict of constant values + args = regular_args_v(args_proxy) + # Create tensormaps and append to args + args = bin.assemble_tensormap_to_arg(args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) + return bin + # kernel not cached -- compile + else: + # build dict of constant values + args = regular_args_v(args_proxy) + all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names]) + configs = self._get_config(*all_args), + constants = self._make_constants(constexpr_key) + constants.update({i: None for i, arg in enumerate(all_args) if arg is None}) + constants.update({i: 1 for i in configs[0].equal_to_1}) + # build kernel signature -- doesn't include specialized arguments + signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs} + # build stub signature -- includes arguments that are specialized + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index 
{i} is not supported") + if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs): + bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) + # Create tensormaps and append to args + args = bin.assemble_tensormap_to_arg(args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) + self.cache[device][key] = bin + return bin + return None + + # create a wrapper to call launcher_body + args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults)) args_signature = args_signature + ', ' if len(args_signature) > 0 else '' - src = f""" import triton def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): - from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages - sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()} - constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()} - spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()} - assert num_ctas > 0 - assert grid is not None - if callable(grid): - grid = grid({{{grid_args}}}) - grid_size = len(grid) - grid_0 = grid[0] - grid_1 = grid[1] if grid_size > 1 else 1 - grid_2 = grid[2] if grid_size > 2 else 1 - - if device_type is None: - device_types = [_device_type for _device_type in {device_types} if _device_type != ''] - device_type = self._conclude_device_type(device_types, {pinned_memory_flags}) - - device_backend = None - if device_type not in ['cuda']: - device_backend = get_backend(device_type) - if device_backend is None: - raise ValueError('Cannot find backend for ' + device_type) - - if device is None: - if device_type in ['cuda']: - device = get_current_device() - set_current_device(device) - else: - device = device_backend.get_current_device() - device_backend.set_current_device(device) - if stream is None and not warmup: - if device_type in ['cuda']: - stream = get_cuda_stream(device) - else: - stream = device_backend.get_stream() - - if num_warps is None: - num_warps = get_arch_default_num_warps(device_type) - if num_stages is None: - num_stages = get_arch_default_num_stages(device_type) - - key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug) - if not extern_libs is None: - key = (key, tuple(extern_libs.items())) - - bin = cache[device].get(key, None) - if bin is not None: - # build dict of constant values - args = [{args}] - # Create tensormaps and append to args - args = bin.assemble_tensormap_to_arg(args) - if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, 
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) - return bin - # kernel not cached -- compile - else: - # build dict of constant values - args = [{args}] - all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()} - configs = self._get_config(*all_args), - constants = self._make_constants(constexpr_key) - constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}}) - constants.update({{i: 1 for i in configs[0].equal_to_1}}) - # build kernel signature -- doesn't include specialized arguments - signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} - # build stub signature -- includes arguments that are specialized - for i, arg in constants.items(): - if callable(arg): - raise TypeError(f"Callable constexpr at index {{i}} is not supported") - if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs): - bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) - # Create tensormaps and append to args - args = bin.assemble_tensormap_to_arg(args) - if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) - self.cache[device][key] = bin - return bin - return None + return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type) """ - scope = {"version_key": version_key(), - "get_cuda_stream": get_cuda_stream, - "self": self, - "_spec_of": self._spec_of, - "_key_of": self._key_of, - "_device_of": self._device_of, - "_pinned_memory_of": self._pinned_memory_of, - "cache": self.cache, - "__spec__": __spec__, - "get_backend": get_backend, - "get_current_device": get_current_device, - "set_current_device": set_current_device} + scope = {"launcher_body": launcher_body} exec(src, scope) return scope[self.fn.__name__] From 57fc6d1f13beb4dd4d7db706139df6711623b872 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Sat, 23 Sep 2023 13:05:20 -0400 Subject: [PATCH 085/122] [BACKEND] `shfl` ptx insts should have side effects (#2376) Otherwise, llvm pass could generate very weird structure of CFG and yield incorrect results. 
https://github.com/openai/triton/issues/2361 --- .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 4 +-- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 2 +- python/test/unit/operators/test_inductor.py | 25 ++++++++++++++++--- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index f91223265a44..c2c4100eef73 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -77,8 +77,7 @@ static void warpScan(SmallVector &srcValues, Value acc = srcValues[srcIndex]; for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) { Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride); - Value tempAcc = acc; - tempAcc = accumulate(rewriter, helper.getCombineOp(), shfl, tempAcc); + Value tempAcc = accumulate(rewriter, helper.getCombineOp(), shfl, acc); Value mask = icmp_slt(laneIdAxis, i32_val(i)); acc = select(mask, acc, tempAcc); } @@ -399,7 +398,6 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, auto input = adaptor.getOperands()[0]; auto type = op.getOperand(0).getType().cast(); auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - auto axisNumThreads = helper.getAxisNumThreadsPerWarpWithUniqueData(); warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); SmallVector srcValues = getTypeConverter()->unpackLLElements(loc, input, rewriter, type); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 6d8fc8509f47..7a3483616d11 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -295,7 +295,7 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, auto *cOpr = builder.newConstantOperand(clamp); auto *maskOpr = builder.newConstantOperand("0xffffffff"); shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); - return builder.launch(rewriter, loc, val.getType(), false); + return builder.launch(rewriter, loc, val.getType()); } Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index 17ba2eb9c8fb..579d0ad935da 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -170,7 +170,26 @@ def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): tl.store(out_ptr + xindex * RBLOCK + rindex, scan) XBLOCK = 4 - input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int32, device='cuda') - output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int32, device='cuda') + input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda') + output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda') fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps) - torch.testing.assert_allclose(output, input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))) + ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) + torch.testing.assert_close(output, ref) + + +def test_scan2d_for(): + @triton.jit + def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): + rbase = tl.arange(0, RBLOCK)[None, :] + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + tmp3 = tl.where(rmask, 1, 0) + tmp6 = tl.cumsum(tmp3, 1) + tl.store(out_ptr0 + rindex, tmp6, rmask) + + RBLOCK = 8 + out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64) + fn[(1,)](out0, RBLOCK, RBLOCK) + ref = torch.arange(RBLOCK, 
device="cuda", dtype=torch.int64) + 1 + torch.testing.assert_close(out0, ref) From a4dbdefe3b6f9cd0aadd9e8cc0987e35d5109eda Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Sat, 23 Sep 2023 11:50:37 -0700 Subject: [PATCH 086/122] [BACKEND] Use shuffle intrinsics instead of inline asm (#2378) This will ensure we get the proper "convergent" semantic for those instructions --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 44 ++++++++++++---------- test/Conversion/tritongpu_to_llvm.mlir | 4 +- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 7a3483616d11..183287fe4445 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -1,7 +1,7 @@ #include "Utility.h" #include "TypeConverter.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "triton/Dialect/NVGPU/IR/Dialect.h" - namespace mlir { namespace LLVM { @@ -270,8 +270,8 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, } static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, const std::string &shuffleType, - const std::string &clamp) { + Value val, Value i, NVVM::ShflKind mode, + Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { @@ -279,43 +279,49 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = commonShflSync(loc, rewriter, val0, i, shuffleType, clamp); - val1 = commonShflSync(loc, rewriter, val1, i, shuffleType, clamp); + val0 = commonShflSync(loc, rewriter, val0, i, mode, clamp); + val1 = commonShflSync(loc, rewriter, val1, i, mode, clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); return bitcast(vec, val.getType()); } - - PTXBuilder builder; - auto &shfl = builder.create("shfl.sync")->o(shuffleType).o("b32"); - auto *dOpr = builder.newOperand("=r"); - auto *aOpr = builder.newOperand(val, "r"); - auto *bOpr = builder.newOperand(i, "r"); - auto *cOpr = builder.newConstantOperand(clamp); - auto *maskOpr = builder.newConstantOperand("0xffffffff"); - shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); - return builder.launch(rewriter, loc, val.getType()); + Type type = val.getType(); + if (type != i32_ty) { + val = bitcast(val, int_ty(bits)); + val = zext(i32_ty, val); + } + Value mask = i32_val(0xFFFFFFFF); + Value result = rewriter.create(loc, i32_ty, mask, val, i, clamp, + mode, UnitAttr()); + if (type != i32_ty) { + result = trunc(int_ty(bits), result); + result = bitcast(result, type); + } + return result; } Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i32_val(i), "bfly", "0x1f"); + return commonShflSync(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, + i32_val(0x1f)); } Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i32_val(i), "up", "0x0"); + return commonShflSync(loc, rewriter, val, i32_val(i), NVVM::ShflKind::up, + i32_val(0x0)); } Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i32_val(i), "idx", "0x1f"); + return shflIdxSync(loc, rewriter, val, 
i32_val(i)); } Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i) { - return commonShflSync(loc, rewriter, val, i, "idx", "0x1f"); + return commonShflSync(loc, rewriter, val, i, NVVM::ShflKind::idx, + i32_val(0x1f)); } Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 834c2eecbcc5..978205ec55ce 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1349,8 +1349,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 // CHECK: nvvm.redux.sync add %{{.*}}, %[[M]] // CHECK: nvvm.barrier0 -// CHECK: shfl.sync.bfly.b32 -// CHECK: shfl.sync.bfly.b32 +// CHECK: nvvm.shfl.sync bfly +// CHECK: nvvm.shfl.sync bfly // CHECK: nvvm.barrier0 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> From e7abafe4b452404feba2b96ed15f5a03a1dcdeb9 Mon Sep 17 00:00:00 2001 From: kshama-msft <66488860+kshama-msft@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:41:45 -0700 Subject: [PATCH 087/122] [DOCS] create tritonconf2023.md (#2390) File and video location. --- docs/meetups/tritonconf2023.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 docs/meetups/tritonconf2023.md diff --git a/docs/meetups/tritonconf2023.md b/docs/meetups/tritonconf2023.md new file mode 100644 index 000000000000..27719b1079e2 --- /dev/null +++ b/docs/meetups/tritonconf2023.md @@ -0,0 +1,27 @@ +The conference slides are available [here](https://drive.google.com/drive/folders/1yDFc4ElNN_GGhWDdMlM4wcm5uFEFFVQk?usp=sharing) + +The conference videos will be available [here](https://youtube.com/playlist?list=PLc_vA1r0qoiRZfUC3o4_yjj0FtWvodKAz&feature=shared) when ready. + +# Triton Developer Conference +The Triton Developer Conference was held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference was held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. 
+ +Agenda for the conference: + +|Time |Title |Speaker +|--------|-------|-------| +|10:00 AM|Welcome|Kevin Scott (Microsoft)| +|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)| +|11:00 AM|**Break**|| +|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)| +|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)| +|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)| +|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)| +|12:30 PM|**Lunch**|| +|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)| +|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)| +|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)| +|2:40 PM|**Break**|| +|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)| +|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)| +|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)| +|4:00 PM|**Reception**|| From 8ae2ae4f40d099589727287ee4db6a01109c2732 Mon Sep 17 00:00:00 2001 From: kshama-msft <66488860+kshama-msft@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:42:04 -0700 Subject: [PATCH 088/122] [DOCS] update README.md (#2391) Remove conference details. --- README.md | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/README.md b/README.md index fbdd3027414b..53611da7709d 100644 --- a/README.md +++ b/README.md @@ -10,30 +10,6 @@ We're hiring! If you are interested in working on Triton at OpenAI, we have role ------------------- | [![Documentation](https://github.com/openai/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) -# Triton Developer Conference Registration Now Closed -The Triton Developer Conference will be held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference will be held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. - -Tentative Agenda for the conference (subject to change): - -|Time |Title |Speaker -|--------|-------|-------| -|10:00 AM|Welcome|Kevin Scott (Microsoft)| -|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)| -|11:00 AM|**Break**|| -|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)| -|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)| -|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)| -|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)| -|12:30 PM|**Lunch**|| -|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)| -|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)| -|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)| -|2:40 PM|**Break**|| -|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)| -|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)| -|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)| -|4:00 PM|**Reception**|| - # Triton From 6bc1d9e1be047bd9570fc012134cf9674e4b10cc Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 25 Sep 2023 10:43:54 -0700 Subject: [PATCH 089/122] [BACKEND] Support MMA V3 with register operand (#2375) MMA V3 support taking operand A from register. This helps for chained matmul operations like in attention. Add an optimization to use this mode when it helps and add the lowering for it. 
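For context, the case this targets is a chained matmul in which the result of one `tt.dot` (kept in MMA layout, i.e. in registers) is used directly as operand A of the next `tt.dot`, as in attention. A rough Triton-level sketch of that pattern follows; it is illustrative only: names, shapes, and pointer arithmetic are made up, and whether the register-operand path is actually taken depends on targeting Hopper with MMA v3 enabled and on the layouts the compiler chooses (the lowering added here also only handles f16 operands).

```python
# Illustrative chained-dot kernel: the first dot's result feeds operand A of
# the second dot. Shapes and names are invented for the sketch.
import torch
import triton
import triton.language as tl


@triton.jit
def chained_dot(q_ptr, k_ptr, v_ptr, out_ptr,
                M: tl.constexpr, K: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_k = tl.arange(0, K)
    offs_n = tl.arange(0, N)
    q = tl.load(q_ptr + offs_m[:, None] * K + offs_k[None, :])  # [M, K]
    k = tl.load(k_ptr + offs_k[:, None] * N + offs_n[None, :])  # [K, N]
    v = tl.load(v_ptr + offs_n[:, None] * K + offs_k[None, :])  # [N, K]
    s = tl.dot(q, k)                 # first MMA; accumulator is fp32
    p = s.to(tl.float16)             # f16 result reused as operand A below
    acc = tl.dot(p, v)               # second MMA consumes the previous result
    tl.store(out_ptr + offs_m[:, None] * K + offs_k[None, :], acc)


# Example launch; all dot dimensions must be tl.dot-compatible (>= 16).
q = torch.randn((64, 64), device="cuda", dtype=torch.float16)
k = torch.randn((64, 64), device="cuda", dtype=torch.float16)
v = torch.randn((64, 64), device="cuda", dtype=torch.float16)
out = torch.empty((64, 64), device="cuda", dtype=torch.float32)
chained_dot[(1,)](q, k, v, out, M=64, K=64, N=64)
```

When the new pattern does not apply (non-Hopper target, non-f16 operands, or mismatched layouts), the existing path that stages operand A through shared memory is kept.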
--- bin/RegisterTritonDialects.h | 3 +- include/triton/Analysis/Utility.h | 4 + lib/Analysis/Allocation.cpp | 2 + lib/Analysis/Utility.cpp | 14 ++ .../NVGPUToLLVM/NVGPUToLLVMPass.cpp | 15 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 5 + .../TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 129 +++++++++--------- .../TritonGPUToLLVM/TypeConverter.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 16 ++- .../Transforms/OptimizeDotOperands.cpp | 37 +++++ test/Conversion/tritongpu_to_llvm_hopper.mlir | 19 ++- test/NVGPU/test_wgmma.mlir | 2 +- test/TritonGPU/dot-operands.mlir | 16 +++ 13 files changed, 186 insertions(+), 78 deletions(-) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 29ba31eaf1f3..e88a9a5395c6 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -12,6 +12,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/InitAllPasses.h" namespace mlir { @@ -42,5 +43,5 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, - mlir::triton::nvgpu::NVGPUDialect>(); + mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect>(); } diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 59af824097f8..364e476be5fd 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -133,6 +133,10 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); +// Return true if the src and dst layout match. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + // TODO: Move utility functions that belong to ConvertLayoutOp to class // ConvertLayoutOpHelper in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index ec3757208016..154a573182bb 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -96,6 +96,8 @@ SmallVector getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, unsigned &outVec) { auto repShape = getRepShapeForCvtLayout(op); + if (repShape.empty()) + return repShape; auto srcTy = op.getSrc().getType().cast(); auto dstTy = op.getResult().getType().cast(); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 1d2a6b2e9bc1..dcf8a7704052 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -424,7 +424,21 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } +// For MMAV3 dotOperand layout matches mma operand for f16 case. 
+bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = srcLayout.cast(); + auto dotOperandLayout = dstLayout.cast(); + return mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + srcTy.getElementType().isF16(); +} + bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; // dot_op = #mma // when #mma = MmaEncoding auto srcLayout = srcTy.getEncoding(); diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index d9f27700d5f1..02b4c024ff35 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -128,7 +128,8 @@ class NVGPUOpPatternBase : public mlir::RewritePattern { assert(val.getType().getIntOrFloatBitWidth() <= ty.getIntOrFloatBitWidth() && "Cannot convert to a smaller type"); - return zext(ty, val); + if (val.getType().getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth()) + return zext(ty, val); } } return val; @@ -731,7 +732,7 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { operandsAndConstraints.push_back({opC, "0"}); if (structTypeA) { - operandsAndConstraints.push_back({opA, "f"}); + operandsAndConstraints.push_back({opA, "r"}); } else { operandsAndConstraints.push_back({opA, "l"}); } @@ -820,8 +821,7 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // Operand A if (structTypeA) { - uint32_t numARegs = m * k / 128; - assert(numARegs == structTypeA.getBody().size()); + uint32_t numARegs = structTypeA.getBody().size(); args += "{"; for (uint32_t i = 0; i < numARegs; ++i) { args += @@ -844,8 +844,11 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { args += ", 1, 1"; // Push `trans-a` and `trans-b` args if needed (determined as constant) - if (needTransArgs) - args += ", " + std::to_string(transA) + ", " + std::to_string(transB); + if (needTransArgs) { + if (!structTypeA) + args += ", " + std::to_string(transA); + args += ", " + std::to_string(transB); + } auto ptxAsm = "wgmma.mma_async.sync.aligned" ".m" + diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 70e675c7bc5f..581252109ee2 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -897,6 +897,11 @@ struct ConvertLayoutOpConversion auto loc = op.getLoc(); auto srcTy = op.getSrc().getType().cast(); auto dstTy = op.getResult().getType().cast(); + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { + rewriter.replaceOp(op, op.getSrc()); + return success(); + } + if (isMmaToDotShortcut(srcTy, dstTy)) { // get source values auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index fb5cfd9d66a9..179b5a269ec2 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -95,13 +95,13 @@ getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) { class DotOpMmaV3SmemLoader { public: - DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj, - SmallVector shape, Value warpId, - unsigned int dimWpt, bool trans, + DotOpMmaV3SmemLoader() {} + 
DotOpMmaV3SmemLoader(Value tensor, Value base, SmallVector shape, + Value warpId, unsigned int dimWpt, bool trans, SmallVector instrShape, ConversionPatternRewriter &rewriter, Location loc) - : base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt), - trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) { + : base(base), shape(shape), warpId(warpId), dimWpt(dimWpt), trans(trans), + instrShape(instrShape) { auto tensorTy = tensor.getType().cast(); auto sharedLayout = tensorTy.getEncoding().cast(); ord = sharedLayout.getOrder(); @@ -118,7 +118,8 @@ class DotOpMmaV3SmemLoader { loc, base, i32_val(shape[ord[1]]), mode); } - Value smemLoad(int a, int b) { + Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc) { Value k = i32_val(b * instrShape[1]); Value m = add(i32_val(a * dimWpt * instrShape[0]), mul(warpId, i32_val(instrShape[0]))); @@ -146,8 +147,6 @@ class DotOpMmaV3SmemLoader { mlir::triton::nvgpu::WGMMADescMode mode; SmallVector instrShape; ArrayRef ord; - ConversionPatternRewriter &rewriter; - Location loc; int elemsPerSwizzlingRow; int elemBytes; Value baseDesc; @@ -156,7 +155,7 @@ class DotOpMmaV3SmemLoader { DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, const MmaEncodingAttr &mmaEncoding, Value tensor, - const SharedMemoryObject &smemObj, Value thread) { + Value smemObjBase, Value thread) { auto aTensorTy = tensor.getType().cast(); auto aSharedLayout = aTensorTy.getEncoding().dyn_cast(); assert(aSharedLayout && "only support load dot operand from shared."); @@ -174,7 +173,7 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0])); return {tensor, - smemObj, + smemObjBase, shapePerCTA, warpId, wpt[0], @@ -187,7 +186,7 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, MmaEncodingAttr &mmaEncoding, Value tensor, - const SharedMemoryObject &smemObj, Value thread) { + Value base, Value thread) { auto bTensorTy = tensor.getType().cast(); auto bSharedLayout = bTensorTy.getEncoding().cast(); assert(bSharedLayout && "only support load B from shared."); @@ -206,7 +205,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1])); return {tensor, - smemObj, + base, shapePerCTA, warpId, wpt[1], @@ -218,9 +217,10 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. 
-llvm::SmallVector loadC(ConversionPatternRewriter &rewriter, - Location loc, const SmallVector &elements, - int startIndex, int numElements) { +llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, + Location loc, + const SmallVector &elements, + int startIndex, int numElements) { if (!elements[0].getType().isF16()) { llvm::SmallVector mmaOut(numElements); for (int i = 0; i < numElements; ++i) @@ -294,19 +294,25 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, - bool allowTF32, uint32_t maxNumImpreciseAcc, - const SharedMemoryObject &smemObjA, - const SharedMemoryObject &smemObjB, bool sync, + bool allowTF32, uint32_t maxNumImpreciseAcc, bool sync, Value thread) { auto aTensorTy = a.getType().cast(); auto bTensorTy = b.getType().cast(); auto dTensorTy = d.getType().cast(); - auto aSharedLayout = aTensorTy.getEncoding().cast(); + auto aSharedLayout = aTensorTy.getEncoding().dyn_cast(); auto bSharedLayout = bTensorTy.getEncoding().cast(); auto mmaEncoding = dTensorTy.getEncoding().cast(); - auto aOrd = aSharedLayout.getOrder(); auto bOrd = bSharedLayout.getOrder(); - bool transA = aOrd[0] == 0; + bool transA = false; + Value baseA; + Value baseB; + if (aSharedLayout) + baseA = getSharedMemoryObjectFromStruct(loc, loadedA, rewriter).base; + baseB = getSharedMemoryObjectFromStruct(loc, loadedB, rewriter).base; + if (aSharedLayout) { + auto aOrd = aSharedLayout.getOrder(); + transA = aOrd[0] == 0; + } bool transB = bOrd[0] == 1; auto dShapePerCTA = getShapePerCTA(dTensorTy); auto instrShape = mmaEncoding.getInstrShape(); @@ -319,11 +325,17 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); int numRepK = ceil(aTensorTy.getShape()[1], instrShape[2]); - - DotOpMmaV3SmemLoader aLoader = - loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread); + DotOpMmaV3SmemLoader aLoader; + SmallVector structA; + if (aSharedLayout) { + aLoader = + loadA(typeConverter, rewriter, loc, mmaEncoding, a, baseA, thread); + } else { + structA = + typeConverter->unpackLLElements(loc, loadedA, rewriter, aTensorTy); + } DotOpMmaV3SmemLoader bLoader = - loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread); + loadB(typeConverter, rewriter, loc, mmaEncoding, b, baseB, thread); auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy); @@ -350,7 +362,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { llvm::SmallVector mmaOut = - loadC(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize); + loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize); llvm::SmallVector elemTypes; for (Value accEl : mmaOut) elemTypes.push_back(accEl.getType()); @@ -362,8 +374,19 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, uint32_t numLowPrecisionAcc = 0; Value partialAcc; for (int k = 0; k < numRepK; ++k) { - auto a = aLoader.smemLoad(m, k); - auto b = bLoader.smemLoad(n, k); + Value a; + if (aSharedLayout) { + a = aLoader.smemLoad(m, k, rewriter, loc); + } else { + unsigned regASize = (instrShape[0] * instrShape[2]) / 32; + llvm::SmallVector regA = loadReg( + rewriter, loc, structA, (m * numRepK + k) * regASize, regASize); + auto regATy = 
LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), + SmallVector(regA.size(), regA[0].getType())); + a = typeConverter->packLLElements(loc, regA, rewriter, regATy); + } + auto b = bLoader.smemLoad(n, k, rewriter, loc); ValueRange operands{a, b, d}; numLowPrecisionAcc += K; // If using native accumulation would cause use to do more low precion @@ -410,28 +433,6 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, return success(); } -// Loading $c to registers, returns a Value. -Value loadC(Value tensor, Value llTensor) { - auto tensorTy = tensor.getType().cast(); - auto mmaEncoding = tensorTy.getEncoding().dyn_cast(); - assert(mmaEncoding && "Currently, we only support $c with a mma layout."); - auto instrShape = mmaEncoding.getInstrShape(); - auto wpt = mmaEncoding.getWarpsPerCTA(); - auto shapePerCTA = getShapePerCTA(tensorTy); - auto shapePerCTATile = getShapePerCTATile(mmaEncoding); - - int numRepM = ceil(shapePerCTA[0], shapePerCTATile[0]); - int numRepN = ceil(shapePerCTA[1], shapePerCTATile[1]); - - size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN; - - auto structTy = llTensor.getType().cast(); - assert(structTy.getBody().size() == fcSize && - "DotOp's $c operand should pass the same number of values as $d in " - "mma layout."); - return llTensor; -} - LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Value thread) { @@ -442,21 +443,19 @@ LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto ATensorTy = A.getType().cast(); auto BTensorTy = B.getType().cast(); - assert(ATensorTy.getEncoding().isa() && - BTensorTy.getEncoding().isa() && - "Both $a and %b should be Shared layout."); + assert(ATensorTy.getEncoding().isa() || + ATensorTy.getEncoding().isa()); + assert(BTensorTy.getEncoding().isa() && + "Operand B should use Shared layout."); Value llA, llB, llC; llA = adaptor.getA(); llB = adaptor.getB(); - llC = loadC(C, adaptor.getC()); + llC = adaptor.getC(); - auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); - auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, op.getD(), llA, llB, llC, op.getAllowTF32(), - op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, true, - thread); + op.getMaxNumImpreciseAcc(), true, thread); } LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, @@ -471,19 +470,17 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, auto ATensorTy = A.getType().cast(); auto BTensorTy = B.getType().cast(); - assert(ATensorTy.getEncoding().isa() && - BTensorTy.getEncoding().isa() && - "Both $a and %b should be Shared layout."); + assert(ATensorTy.getEncoding().isa() || + ATensorTy.getEncoding().isa()); + assert(BTensorTy.getEncoding().isa() && + "Operand B should use Shared layout."); Value llA, llB, llC; llA = adaptor.getA(); llB = adaptor.getB(); - llC = loadC(C, adaptor.getC()); + llC = adaptor.getC(); - auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); - auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, op.getD(), llA, llB, llC, op.getAllowTF32(), - op.getMaxNumImpreciseAcc(), smemObjA, smemObjB, false, - thread); + op.getMaxNumImpreciseAcc(), false, thread); } diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 
b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index cf388e27cb4a..6513b8bda671 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -128,7 +128,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( if (!dotOpLayout) return elemTy; auto mmaParent = dotOpLayout.getParent().dyn_cast(); - if (!mmaParent) + if (!mmaParent || mmaParent.isHopper()) return elemTy; int bitwidth = elemTy.getIntOrFloatBitWidth(); assert(bitwidth <= 32); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 58437f850ac7..a36c30533841 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -853,6 +853,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, if (auto mmaParent = getParent().dyn_cast()) { int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0]; int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1]; + // H100 + if (mmaParent.isHopper()) { + if (eltTy.isF16()) + return mmaParent.getTotalElemsPerThread(shape, eltTy); + } // A100 if (mmaParent.isAmpere()) { auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth()); @@ -1472,10 +1477,13 @@ struct TritonGPUInferLayoutInterface std::optional location) const override { auto mmaRetEncoding = retEncoding.dyn_cast(); if (mmaRetEncoding && mmaRetEncoding.isHopper()) { - // TODO: support gmma when A/B does not reside in shared memory - if (!operandEncoding.isa()) + auto dotOpEnc = operandEncoding.dyn_cast(); + if (!operandEncoding.isa() && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + dotOpEnc.getParent() == mmaRetEncoding)) { return emitOptionalError( location, "unexpected operand layout for MmaEncodingAttr v3"); + } } else if (auto dotOpEnc = operandEncoding.dyn_cast()) { if (opIdx != dotOpEnc.getOpIdx()) @@ -1497,6 +1505,10 @@ struct TritonGPUInferLayoutInterface operandEncodingB.dyn_cast(); if (!aEncoding && !bEncoding) return mlir::success(); + auto mmaAEncoding = + aEncoding.getParent().dyn_cast_or_null(); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); // Verify that the encodings are valid. if (!aEncoding || !bEncoding) return op->emitError("mismatching encoding between A and B operands"); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5a1d93c2569a..15e71ea201aa 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -231,6 +231,42 @@ class FuseTransHopper : public mlir::RewritePattern { } }; +struct MMAV3UseRegOperand : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::DotOp dotOp, + PatternRewriter &rewriter) const override { + auto convertLhs = + dotOp.getOperand(0).getDefiningOp(); + if (!convertLhs) + return failure(); + auto getEncoding = [](Value v) { + return v.getType().cast().getEncoding(); + }; + if (!getEncoding(dotOp.getOperand(0)).isa()) + return failure(); + auto srcEncoding = + getEncoding(convertLhs.getSrc()).dyn_cast(); + if (!srcEncoding || srcEncoding.getVersionMajor() != 3 || + srcEncoding != getEncoding(dotOp.getResult())) + return failure(); + // We currently only support convert from f16 mma to f16 dot operand as the + // other types require shuffling data across threads. + // TODO: extend it to more types. 
+ auto srcType = convertLhs.getSrc().getType().cast(); + if (!srcType.getElementType().isF16()) + return failure(); + auto dotOperandEncoding = + DotOperandEncodingAttr::get(dotOp.getContext(), 0, srcEncoding, 0); + auto newType = RankedTensorType::get( + srcType.getShape(), srcType.getElementType(), dotOperandEncoding); + Value newOperand = rewriter.create(dotOp.getLoc(), newType, + convertLhs.getSrc()); + rewriter.updateRootInPlace(dotOp, + [&]() { dotOp.setOperand(0, newOperand); }); + return success(); + } +}; } // namespace #define GEN_PASS_CLASSES @@ -255,6 +291,7 @@ class TritonGPUOptimizeDotOperandsPass if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) patterns.add(context); patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 15020d058366..33db273c8ed5 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> @@ -166,3 +166,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_reg_operand_A + // Generate a wgmma where the first operand is a struct. 
+ // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: tensor<64x64xf16, #shared>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %opA = triton_gpu.convert_layout %a : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %m = tt.dot %opA, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return + } +} diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir index ee059b329fb1..9ad7e606c2be 100644 --- a/test/NVGPU/test_wgmma.mlir +++ b/test/NVGPU/test_wgmma.mlir @@ -22,7 +22,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @wgmma_no_acc(%descA: i64, %descB: i64) { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %0, %1 : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %{{.*}}, %{{.*}} : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> %acc0 = nvgpu.wgmma %descA, %descB {eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} : (i64, i64) -> diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 039c9429438a..1cd3b772deb2 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -137,3 +137,19 @@ tt.func @update_kwidth_slice( } } + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_operand_A +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> +// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: tensor<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %A = triton_gpu.convert_layout %arg0 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #shared1> + %r = tt.dot %A, %arg1, %arg2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared1> * tensor<64x64xf16, #shared> -> 
tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> +} +} From d040b58547ef76fc7445cfd4d9fd1400e40481e0 Mon Sep 17 00:00:00 2001 From: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Date: Tue, 26 Sep 2023 02:29:49 +0800 Subject: [PATCH 090/122] [HOPPER] fix ref check failure of flash attention with mma v3 (#2384) --- .../test/unit/operators/test_flash_attention.py | 7 ------- python/triton/ops/flash_attention.py | 17 +++++++++++++++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 179410faea0b..33c22675d735 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -18,13 +18,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): if enable_tma in ["on", "true", "1"]: if dtype == torch.bfloat16: pytest.skip('bfloat16 tma not support currently') - if '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD])) in [ - "True-True-2-4-512-16", - "True-True-2-4-512-32", - "True-False-2-4-512-16", - "True-False-2-4-512-32", - ]: - pytest.skip('backward ref check failed') capability = torch.cuda.get_device_capability() interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"] diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 187ec21377af..752c7d2f822d 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -143,6 +143,7 @@ def _bwd_kernel_one_col_block( BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + MMA_V3: tl.constexpr ): if SEQUENCE_PARALLEL: DQ += stride_dqa.to(tl.int64) * start_n @@ -202,8 +203,11 @@ def _bwd_kernel_one_col_block( dq += tl.dot(ds, k, allow_tf32=True) tl.store(dq_ptrs, dq) elif SEQUENCE_PARALLEL: - # dq = tl.dot(ds, k, allow_tf32=True) - dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) + if MMA_V3: + dq = tl.dot(ds, k, allow_tf32=True) + else: + # not work with mma v3, becuase M % 64 != 0 + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) tl.store(dq_ptrs, dq) # increment pointers @@ -233,6 +237,7 @@ def _bwd_kernel( BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + MMA_V3: tl.constexpr # fmt: on ): qk_scale = sm_scale * 1.44269504 @@ -265,6 +270,7 @@ def _bwd_kernel( BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + MMA_V3=MMA_V3 ) else: start_n = tl.program_id(1) @@ -282,6 +288,7 @@ def _bwd_kernel( BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + MMA_V3=MMA_V3 ) @@ -328,6 +335,11 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): @staticmethod def backward(ctx, do): + import os + enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() + MMA_V3 = False + if enable_mmav3 in ["on", "true", "1"]: + MMA_V3 = True BLOCK = 128 q, k, v, o, L = ctx.saved_tensors sequence_parallel = ctx.sequence_parallel @@ -361,6 +373,7 @@ def backward(ctx, do): BLOCK_DMODEL=ctx.BLOCK_DMODEL, SEQUENCE_PARALLEL=sequence_parallel, CAUSAL=ctx.causal, + MMA_V3=MMA_V3, num_warps=8, num_stages=1, ) From 00c089d897d120a5c4926ac9b4d5186daca206b0 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 25 Sep 2023 12:18:28 -0700 Subject: [PATCH 091/122] [DOCS] tweak install instructions for custom llvm build. 
(#2381) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 53611da7709d..c522144723f4 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ arbitrary LLVM version. $ cd $HOME/llvm-project # your clone of LLVM. $ mkdir build $ cd build - $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir" + $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" $ ninja 4. Grab a snack, this will take a while. @@ -91,9 +91,9 @@ arbitrary LLVM version. # Modify as appropriate to point to your LLVM build. $ export LLVM_BUILD_DIR=$HOME/llvm-project/build - $ cd /python + $ cd $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ - LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR \ + LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ LLVM_SYSPATH=$LLVM_BUILD_DIR \ pip install -e python From 80adbbb87ba116947a17643011a08ef20d8a3c02 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 25 Sep 2023 17:05:04 -0700 Subject: [PATCH 092/122] [OPTIMIZER] fix-up acceleratematmul (#2392) --- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index c455f20dc1ad..2bef4904c05b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -524,7 +524,8 @@ void getForwardSliceSCFAware(Value root, SetVector *slices) { for (Operation *sliceOp : temp) { if (auto yieldOp = dyn_cast(sliceOp)) { auto forOp = yieldOp->getParentOfType(); - queue.append(forOp->getResults().begin(), forOp->getResults().end()); + if (forOp) + queue.append(forOp->getResults().begin(), forOp->getResults().end()); } } slices->insert(temp.begin(), temp.end()); From eea071844569188d51555c4bfe0de9c984bef133 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 25 Sep 2023 21:41:26 -0700 Subject: [PATCH 093/122] [TESTING] better cudagraph-based benchmarking (#2394) --- python/test/regression/test_performance.py | 84 +++++++++++----------- python/triton/testing.py | 41 ++++++----- 2 files changed, 62 insertions(+), 63 deletions(-) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index dcab2e168ade..ec9966d29142 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -29,22 +29,22 @@ def print_perf(cur_ms, cur_util, ref_util): # NOTE: 'a100': { # square - (512, 512, 512): {'float16': 0.061, 'float32': 0.097, 'int8': 0.05}, - (1024, 1024, 1024): {'float16': 0.283, 'float32': 0.313, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.618, 'float32': 0.532, 'int8': 0.34}, - (8192, 8192, 8192): {'float16': 0.786, 'float32': 0.754, 'int8': 0.51}, + (512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05}, + (1024, 1024, 1024): {'float16': 0.355, 'float32': 0.313, 'int8': 0.169}, + (2048, 2048, 2048): {'float16': 0.653, 'float32': 0.532, 'int8': 0.34}, + (8192, 8192, 8192): {'float16': 0.839, 'float32': 0.754, 'int8': 0.51}, # tall-skinny - (16, 1024, 1024): {'float16': 0.006, 'float32': 0.009, 'int8': 0.005}, - (16, 4096, 4096): {'float16': 0.057, 'float32': 0.051, 'int8': 0.026}, - (16, 8192, 8192): {'float16': 0.077, 'float32': 0.077, 'int8': 0.043}, - (64, 1024, 1024): {'float16': 0.018, 'float32': 0.023, 'int8': 0.017}, - (64, 4096, 4096): {'float16': 0.150, 'float32': 0.000, 'int8': 0.097}, - (64, 
8192, 8192): {'float16': 0.214, 'float32': 0.000, 'int8': 0.174}, - (1024, 64, 1024): {'float16': 0.029, 'float32': 0.046, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.136, 'float32': 0.214, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177}, + (16, 1024, 1024): {'float16': 0.015, 'float32': 0.009, 'int8': 0.005}, + (16, 4096, 4096): {'float16': 0.080, 'float32': 0.051, 'int8': 0.026}, + (16, 8192, 8192): {'float16': 0.083, 'float32': 0.077, 'int8': 0.043}, + (64, 1024, 1024): {'float16': 0.045, 'float32': 0.023, 'int8': 0.017}, + (64, 4096, 4096): {'float16': 0.170, 'float32': 0.000, 'int8': 0.097}, + (64, 8192, 8192): {'float16': 0.227, 'float32': 0.000, 'int8': 0.174}, + (1024, 64, 1024): {'float16': 0.040, 'float32': 0.046, 'int8': 0.017}, + (4096, 64, 4096): {'float16': 0.160, 'float32': 0.214, 'int8': 0.102}, + (8192, 64, 8192): {'float16': 0.272, 'float32': 0.000, 'int8': 0.177}, # test EVEN_K==False - (8192, 8192, 8176): {'float16': 0.786, 'float32': 0.743, 'int8': 0.51}, + (8192, 8192, 8176): {'float16': 0.828, 'float32': 0.743, 'int8': 0.51}, } } @@ -100,15 +100,15 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements, elementwise_data = { 'a100': { - 1024 * 16: {'float16': 0.003, 'float32': 0.007}, - 1024 * 64: {'float16': 0.013, 'float32': 0.026}, - 1024 * 256: {'float16': 0.053, 'float32': 0.105}, - 1024 * 1024: {'float16': 0.212, 'float32': 0.420}, - 1024 * 16384: {'float16': 0.762, 'float32': 0.812}, - 1024 * 65536: {'float16': 0.846, 'float32': 0.869}, + 1024 * 16: {'float16': 0.031, 'float32': 0.060}, + 1024 * 64: {'float16': 0.120, 'float32': 0.224}, + 1024 * 256: {'float16': 0.394, 'float32': 0.691}, + 1024 * 1024: {'float16': 1.06, 'float32': 1.453}, + 1024 * 16384: {'float16': 0.832, 'float32': 0.862}, + 1024 * 65536: {'float16': 0.873, 'float32': 0.882}, # Non pow 2 - 1020 * 100: {'float16': 0.020, 'float32': 0.041}, - 10003 * 7007: {'float16': 0.513, 'float32': 0.861}, + 1020 * 100: {'float16': 0.173, 'float32': 0.327}, + 10003 * 7007: {'float16': 0.522, 'float32': 0.873}, } } @@ -143,30 +143,30 @@ def test_elementwise(N, dtype_str): flash_attention_data = { "a100": { - (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.532, + (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542, (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471, - (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.150, - (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204, + (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155, + (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203, (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202, - (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089, - (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.298, - (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.263, - (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.095, - (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136, + (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108, + (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306, + (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266, + (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098, + (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134, (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135, - (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052, - (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.525, + (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066, + 
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541, (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471, (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150, - (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265, - (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257, - (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128, - (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.297, - (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.263, - (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.095, + (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.263, + (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255, + (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144, + (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306, + (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266, + (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098, (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159, - (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138, - (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076, + (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136, + (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088, } } @@ -238,8 +238,8 @@ def _sum(x_ptr, y_ptr, output_ptr, n_elements, reduction_data = { 'a100': { - 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.015, 'int32': 0.031}, - 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.015, 'int32': 0.032}, + 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.022, 'int32': 0.048}, + 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.022, 'int32': 0.049}, } } diff --git a/python/triton/testing.py b/python/triton/testing.py index da7664adda86..f01d4f8e3faf 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -32,8 +32,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None): """ if torch.cuda.current_stream() == torch.cuda.default_stream(): raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.") - # record CUDAGraph + # warmup fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough if grad_to_none is not None: for x in grad_to_none: x.detach_() @@ -43,39 +46,35 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None): with torch.cuda.graph(g): fn() torch.cuda.synchronize() - fn = lambda: g.replay() - # Estimate the runtime of the function start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - fn() + g.replay() end_event.record() torch.cuda.synchronize() estimate_ms = start_event.elapsed_time(end_event) - # compute number of repetition to last `rep` ms n_repeat = max(1, int(rep / estimate_ms)) - # compute number of repetition to last `rep` ms - start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - ret = [] - n_retries = 50 - for _ in range(n_retries): - # Benchmark - torch.cuda.synchronize() + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): for i in range(n_repeat): - # we don't want `fn` to accumulate gradient values - # if it contains a backward pass. 
So we clear the - # provided gradients if grad_to_none is not None: for x in grad_to_none: x.grad = None - # record time of `fn` - start_event[i].record() fn() - end_event[i].record() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() torch.cuda.synchronize() - times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]) - ret.append(torch.min(times)) + ret += [start_event.elapsed_time(end_event) / n_repeat] return torch.mean(torch.tensor(ret)).item() From 7432fff4bef07ff7c6c70e2c9ca7f4f0c1f433b3 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 25 Sep 2023 23:58:25 -0700 Subject: [PATCH 094/122] [FRONTEND] add limited introspection capabilities in `tl.extra.cuda` ; rename `arch` into `target` (#2385) --- python/test/unit/language/test_core.py | 16 +++ python/triton/compiler/code_generator.py | 10 +- python/triton/compiler/compiler.py | 145 ++++++++++++----------- python/triton/compiler/target.py | 0 python/triton/language/extra/cuda.py | 5 + python/triton/language/semantic.py | 24 ++-- 6 files changed, 114 insertions(+), 86 deletions(-) create mode 100644 python/triton/compiler/target.py diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 53f026b3c42f..63d5b6c3acf7 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3587,6 +3587,22 @@ def nested_while(data, countPtr): # test extra # ----------------------- +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + check_cuda_only(device) + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads,), dtype=np.int32), device=device) + kernel[(1,)](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + def test_globaltimer(device): if is_hip(): diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 328474642aea..8480e7d182cc 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -199,7 +199,7 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, arch, + def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context @@ -208,7 +208,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # node.lineno starts from 1, so we need to subtract 1 self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) - self.builder.arch = arch + self.builder.target = target self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype @@ -912,7 +912,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): file_name, begin_line = _get_fn_file_line(fn) generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, 
module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline, - file_name=file_name, begin_line=begin_line, arch=self.builder.arch) + file_name=file_name, begin_line=begin_line, target=self.builder.target) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1106,7 +1106,7 @@ def kernel_suffix(signature, specialization): return suffix -def ast_to_ttir(fn, signature, specialization, constants, debug, arch): +def ast_to_ttir(fn, signature, specialization, constants, debug, target): # canonicalize signature if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} @@ -1135,7 +1135,7 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, arch): generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line, - arch=arch) + target=target) try: generator.visit(fn.parse()) except CompilationError as e: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 85c7b460a0ae..07414e8739e9 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import Any +from dataclasses import dataclass + from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, get_num_warps, get_shared_memory_size, ir, runtime, @@ -30,6 +32,16 @@ get_ids_of_tensormaps, parse_tma_info) +@dataclass +class CudaTargetDescriptor: + capability: int + num_warps: int + + +def _is_cuda(target): + return isinstance(target, CudaTargetDescriptor) + + class LazyDict(dict): def __getitem__(self, key): val = dict.__getitem__(self, key) @@ -46,20 +58,20 @@ def inline_triton_ir(mod): return mod -def ttir_compute_capability_rewrite(mod, arch): +def ttir_compute_capability_rewrite(mod, target): # For hardware without support, we must rewrite all load/store # with block (tensor) pointers into tensors of pointers pm = ir.pass_manager(mod.context) pm.enable_debug() - if _is_cuda(arch): - pm.add_rewrite_tensor_pointer_pass(arch) + if _is_cuda(target): + pm.add_rewrite_tensor_pointer_pass(target.capability) pm.run(mod) return mod -def optimize_ttir(mod, arch): +def optimize_ttir(mod, target): mod = inline_triton_ir(mod) - mod = ttir_compute_capability_rewrite(mod, arch) + mod = ttir_compute_capability_rewrite(mod, target) pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_inliner_pass() @@ -73,27 +85,30 @@ def optimize_ttir(mod, arch): return mod -def ttir_to_ttgir(mod, num_warps, num_ctas, arch): +def ttir_to_ttgir(mod, num_warps, num_ctas, target): pm = ir.pass_manager(mod.context) pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, arch) + pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) pm.run(mod) return mod -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue): + is_cuda = _is_cuda(target) + if is_cuda: + capability = target.capability pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_tritongpu_coalesce_pass() # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass pm.add_plan_cta_pass(cluster_info) - if 
_is_cuda(arch): - pm.add_tritongpu_rewrite_tensor_pointer_pass(arch) + if is_cuda: + pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) pm.add_plan_cta_pass(cluster_info) pm.add_tritongpu_remove_layout_conversions_pass() - if isinstance(arch, int): - pm.add_tritongpu_accelerate_matmul_pass(arch) + if is_cuda: + pm.add_tritongpu_accelerate_matmul_pass(capability) pm.add_tritongpu_remove_layout_conversions_pass() if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() @@ -104,24 +119,22 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, # it's the responsibility of the compiler to figure out the exact # `num_warps` to use. # TODO: support the case where `num_warps` from user is not 4. - if arch // 10 >= 9 and enable_warp_specialization and num_warps == 4: - pm.add_tritongpu_ws_feasibility_checking_pass(arch) + if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: + pm.add_tritongpu_ws_feasibility_checking_pass(capability) pm.run(mod) ws_enabled = ir.is_ws_supported(mod) pm = ir.pass_manager(mod.context) pm.enable_debug() if ws_enabled: - pm.add_tritongpu_wsdecomposing_pass(arch) - pm.add_tritongpu_wspipeline_pass( - num_stages, num_warps, arch) - pm.add_tritongpu_wsmutex_pass(arch) - pm.add_tritongpu_wsmaterialization_pass(arch) + pm.add_tritongpu_wsdecomposing_pass(capability) + pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability) + pm.add_tritongpu_wsmutex_pass(capability) + pm.add_tritongpu_wsmaterialization_pass(capability) pm.add_cse_pass() else: - pm.add_tritongpu_pipeline_pass( - num_stages, num_warps, num_ctas, arch) - pm.add_tritongpu_materialize_load_store_pass(num_warps, arch) - if arch // 10 <= 8: + pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) + pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) + if capability // 10 <= 8: pm.add_tritongpu_prefetch_pass() pm.add_tritongpu_optimize_dot_operands_pass() pm.add_tritongpu_remove_layout_conversions_pass() @@ -130,7 +143,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, pm.add_tritongpu_reorder_instructions_pass() pm.add_cse_pass() pm.add_symbol_dce_pass() - if arch // 10 >= 9: + if capability // 10 >= 9: pm.add_tritongpu_fence_insertion_pass() pm.add_tritongpu_ws_fixup_missing_attrs_pass() pm.run(mod) @@ -144,12 +157,12 @@ def _add_external_libs(mod, libs): add_external_libs(mod, list(libs.keys()), list(libs.values())) -def ttgir_to_llir(mod, extern_libs, arch, tma_infos): +def ttgir_to_llir(mod, extern_libs, target, tma_infos): if extern_libs: _add_external_libs(mod, extern_libs) # TODO: separate tritongpu_to_llvmir for different backends - if _is_cuda(arch): - return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM) + if _is_cuda(target): + return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) else: return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL) @@ -172,7 +185,7 @@ def ptx_get_version(cuda_version) -> int: raise RuntimeError("Triton only support CUDA 10.0 or higher") -def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str: +def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str: ''' Translate TritonGPU module to PTX code. 
:param mod: a TritonGPU dialect module @@ -181,10 +194,10 @@ def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str: if ptx_version is None: _, cuda_version = path_to_ptxas() ptx_version = ptx_get_version(cuda_version) - return translate_llvmir_to_ptx(mod, arch, ptx_version) + return translate_llvmir_to_ptx(mod, target.capability, ptx_version) -def ptx_to_cubin(ptx: str, arch: int): +def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): ''' Compile TritonGPU module to cubin. :param ptx: ptx code @@ -192,7 +205,7 @@ def ptx_to_cubin(ptx: str, arch: int): :return: str ''' ptxas, _ = path_to_ptxas() - return compile_ptx_to_cubin(ptx, ptxas, arch) + return compile_ptx_to_cubin(ptx, ptxas, target.capability) # ------------------------------------------------------------------------------ @@ -220,7 +233,7 @@ def convert_type_repr(x): return x -def make_hash(fn, arch, env_vars, **kwargs): +def make_hash(fn, target, env_vars, **kwargs): if isinstance(fn, JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] @@ -235,7 +248,7 @@ def make_hash(fn, arch, env_vars, **kwargs): get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) ignore_version = kwargs.get('ignore_version', False) @@ -301,12 +314,7 @@ def parse_mlir_module(path, context): instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()]) -# TODO: architecture descriptor class -def _is_cuda(arch): - return isinstance(arch, int) - - -def get_architecture_descriptor(capability): +def get_cuda_capability(capability): if capability is None: device = get_current_device() capability = get_device_capability(device) @@ -322,15 +330,12 @@ def get_arch_default_num_warps(device_type): assert _device_backend arch = _device_backend.get_architecture_descriptor() num_warps = arch["num_warps"] - return num_warps def get_arch_default_num_stages(device_type, capability=None): - if device_type in ["cuda", "hip"]: - arch = get_architecture_descriptor(capability) - is_cuda = device_type == "cuda" and _is_cuda(arch) - num_stages = 3 if is_cuda and arch >= 75 else 2 + if device_type == "cuda": + num_stages = 3 if get_cuda_capability(capability) >= 75 else 2 else: _device_backend = get_backend(device_type) assert _device_backend @@ -340,12 +345,12 @@ def get_arch_default_num_stages(device_type, capability=None): return num_stages -def add_cuda_stages(arch, extern_libs, stages): +def add_cuda_stages(target, extern_libs, stages): stages["ptx"] = (lambda path: Path(path).read_text(), - lambda src: llir_to_ptx(src, arch)) + lambda src: llir_to_ptx(src, target)) stages["cubin"] = (lambda path: Path(path).read_bytes(), - lambda src: ptx_to_cubin(src, arch)) + lambda src: ptx_to_cubin(src, target)) def compile(fn, **kwargs): 
@@ -355,18 +360,10 @@ def compile(fn, **kwargs): if is_hip(): device_type = "hip" - - if device_type == "cuda": - _device_backend = get_backend(device_type) - arch = get_architecture_descriptor(capability) - else: - _device_backend = get_backend(device_type) - assert _device_backend - arch = _device_backend.get_architecture_descriptor(**kwargs) - - is_cuda = device_type == "cuda" and _is_cuda(arch) + is_cuda = device_type == "cuda" if is_hip(): is_cuda = False + context = ir.context() constants = kwargs.get("constants", dict()) num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) @@ -392,25 +389,33 @@ def compile(fn, **kwargs): cluster_info.clusterDimY = kwargs["clusterDims"][1] cluster_info.clusterDimZ = kwargs["clusterDims"][2] tma_infos = TMAInfos() + # build architecture descriptor + if device_type == "cuda": + _device_backend = get_backend(device_type) + target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps) + else: + _device_backend = get_backend(device_type) + assert _device_backend + target = _device_backend.get_architecture_descriptor(**kwargs) # build compilation stages stages = dict() stages["ast"] = (lambda path: fn, None) stages["ttir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch)) + lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) if is_cuda: stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) - add_cuda_stages(arch, extern_libs, stages) + lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) + add_cuda_stages(target, extern_libs, stages) elif device_type == "hip": - _device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) + _device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) else: # pass the user's configuration to the backend device. 
- arch["num_warps"] = num_warps - arch["num_stages"] = num_stages - arch["num_ctas"] = num_ctas - _device_backend.add_stages(arch, extern_libs, stages) + target["num_warps"] = num_warps + target["num_stages"] = num_stages + target["num_ctas"] = num_ctas + _device_backend.add_stages(target, extern_libs, stages) # find out the signature of the function if isinstance(fn, JITFunction): @@ -444,11 +449,11 @@ def compile(fn, **kwargs): first_stage = list(stages.keys()).index(ir_name) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs)) # managers used to dump and override IR for debugging enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" - fn_override_manager = get_override_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True)) - fn_dump_manager = get_dump_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True)) + fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True)) + fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True)) # determine name and extension type of provided function if isinstance(fn, JITFunction): @@ -481,7 +486,7 @@ def compile(fn, **kwargs): "enable_persistent": enable_persistent, "constants": _get_jsonable_constants(constants), "debug": debug, - "arch": arch, } + "target": target, } metadata.update(get_env_vars()) if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" diff --git a/python/triton/compiler/target.py b/python/triton/compiler/target.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/triton/language/extra/cuda.py b/python/triton/language/extra/cuda.py index d69120938185..8c4114739309 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -13,3 +13,8 @@ def smid(_builder=None): return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, _builder=_builder) + + +@core.builtin +def num_threads(_builder=None): + return core.constexpr(_builder.target.num_warps * 32) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index c9b6fef79cfa..835fef198bd3 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -9,6 +9,13 @@ T = TypeVar('T') +# TODO: redundant code -- remove after 3P backend refactor + + +def _is_cuda(target): + from ..compiler.compiler import CudaTargetDescriptor + return isinstance(target, CudaTargetDescriptor) + # Create custom exception that prints message "hello" @@ -681,11 +688,6 @@ def bitcast(input: tl.tensor, dst_ty) -# TODO: architecture descriptor class -def _is_cuda(arch): - return isinstance(arch, int) - - def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: @@ -700,7 +702,7 @@ def cast(input: tl.tensor, src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - if _is_cuda(builder.arch) and builder.arch < 89 and \ + if _is_cuda(builder.target) and builder.target.capability < 89 and \ (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): assert False, "fp8e4nv data type is not supported on CUDA arch < 89" @@ -1290,13 +1292,13 @@ def dot(lhs: tl.tensor, max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): + def 
assert_dtypes_valid(lhs_dtype, rhs_dtype, target): # Checks for non-cuda archs - if not _is_cuda(builder.arch): + if not _is_cuda(target): assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" return # Checks for cuda arch - if arch < 90: + if target.capability < 90: assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): return @@ -1317,7 +1319,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): assert lhs.type.is_block() and rhs.type.is_block() - assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch) + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target) assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" @@ -1375,7 +1377,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): assert acc.type == ret_ty # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 - if not (_is_cuda(builder.arch) and builder.arch == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()): + if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()): max_num_imprecise_acc = 0 if max_num_imprecise_acc is None: max_num_imprecise_acc = 2**30 From 2d28b09319cb89284dbc89379d24342b8ad8d2a3 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Tue, 26 Sep 2023 17:08:13 +0200 Subject: [PATCH 095/122] Forward fixes to build on newer version of llvm (#2388) --- .../Transforms/RemoveLayoutConversions.cpp | 13 ++++++------- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- .../Transforms/RewriteTensorPointer.cpp | 14 +++++++------- lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp | 4 +++- .../TritonNvidiaGPU/Transforms/WSPipeline.cpp | 4 ++-- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index b7f88948b982..0e49c0ce5dd7 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -761,16 +761,15 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, // Create a new loop before the existing one, with the extra operands. 
rewriter.setInsertionPoint(loop); - auto operands = llvm::to_vector<4>(loop.getIterOperands()); + auto operands = llvm::to_vector<4>(loop.getInitArgs()); operands.append(newIterOperands.begin(), newIterOperands.end()); scf::ForOp newLoop = rewriter.create( loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), operands); newLoop.getBody()->erase(); - newLoop.getLoopBody().getBlocks().splice( - newLoop.getLoopBody().getBlocks().begin(), - loop.getLoopBody().getBlocks()); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); for (Value operand : newIterOperands) newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); @@ -805,9 +804,9 @@ static void rewriteSlice(SetVector &slice, for (auto arg : forOp.getRegionIterArgs()) { if (slice.count(arg)) { OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); - argMapping.push_back( - std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal), - forOp.getNumIterOperands() + newOperands.size())); + argMapping.push_back(std::make_pair( + forOp.getResultForOpOperand(initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); newOperands.push_back(mapping.lookup(initVal.get())); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 2bef4904c05b..cc17e4e87815 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -613,7 +613,7 @@ struct ForOpDeadArgElimination : public OpRewritePattern { Value yieldOperand = forOwner.getBody()->getTerminator()->getOperand(iterIdx); markLive(yieldOperand); - markLive(forOwner.getIterOperands()[iterIdx]); + markLive(forOwner.getInitArgs()[iterIdx]); } } SmallVector deadArg; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp index e13cf8bd9179..5a2d3beaa4b4 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -523,9 +523,9 @@ class TritonGPURewriteTensorPointerPass std::stack &eraser, DenseSet &valueToRemove) { // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = op.getIterOperands(); - SmallVector newIterOperands = op.getIterOperands(); - for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; ++i, ++oldI) { if (!tt::isTensorPointerType(newIterOperands[i].getType())) continue; @@ -550,7 +550,7 @@ class TritonGPURewriteTensorPointerPass // mapping. 
It may refer to a value in the old loop, but we will rewrite it // later IRMapping mapping; - for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size(); ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); if (tt::isTensorPointerType(oldRegionIterArg.getType()) && @@ -586,7 +586,7 @@ class TritonGPURewriteTensorPointerPass valueToRemove.insert(v); // Replace later usages - assert(op.getNumResults() == op.getNumIterOperands()); + assert(op.getNumResults() == op.getInitArgs().size()); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); if (tt::isTensorPointerType(oldResult.getType()) && @@ -787,8 +787,8 @@ class TritonGPURewriteTensorPointerPass } } if (auto forOp = dyn_cast(op)) { - SmallVector iterOperands = forOp.getIterOperands(); - for (unsigned i = 0, size = forOp.getNumIterOperands(); i < size; ++i) { + SmallVector iterOperands = llvm::to_vector(forOp.getInitArgs()); + for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) { if (tt::isTensorPointerType(iterOperands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); if (shouldRemove(makeTensorPtrOp, computeCapability)) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp index 4ed9f0c64996..dd39e94b8569 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp @@ -153,7 +153,9 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp, Value newIdx = builder.createWithAgentIds(loc, pipelineIdx, curRoleId); - persistentForOp.setIterArg(persistentForOp.getNumIterOperands() - 1, newIdx); + persistentForOp.getInitArgsMutable() + .slice(persistentForOp.getInitArgs().size() - 1, 1) + .assign(newIdx); auto yield = llvm::cast(persistentForOp.getBody()->getTerminator()); auto idxPlusOneOp = diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp index b7488a8ba4ea..5d6417fab24a 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -162,7 +162,7 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages, // Copy iter operands of forOp SmallVector newLoopArgs; - for (auto operand : forOp.getIterOperands()) + for (auto operand : llvm::to_vector(forOp.getInitArgs())) newLoopArgs.push_back(operand); // Append initial value of pipelineIdx to newLoopArgs @@ -302,7 +302,7 @@ DenseMap createForOpsForEachAgentId(scf::ForOp forOp) { // Prepare newLoopArgs SmallVector newLoopArgs; for (unsigned argNumber : usedArgs) - newLoopArgs.push_back(forOp.getIterOperands()[argNumber]); + newLoopArgs.push_back(forOp.getInitArgs()[argNumber]); // Create newForOp builder.setAgentIdsFromArray({agentId}); From bf3171f5c735ea216fb624107c807e4e026c5638 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Tue, 26 Sep 2023 19:12:32 +0200 Subject: [PATCH 096/122] Lit test to check for illegal st.shared.b1 llvmir (#2387) --- test/Conversion/tritongpu_to_llvm.mlir | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 978205ec55ce..c7da669da35d 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1435,3 +1435,23 @@ module attributes 
{"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c tt.return } } + +// ----- + +// CHECK-LABEL: copyitem +// CHECK: st.shared.b8 +// CHECK: ld.shared.b8 +// CHECK-NOT: st.shared.b1 +// CHECK-NOT: ld.shared.b1 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @copyitem() attributes {noinline = false} { + %cst = arith.constant dense : tensor<4x1xi1, #blocked> + %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %1 = arith.ori %arg0, %arg1 : i1 + tt.reduce.return %1 : i1 + }) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + tt.return + } +} From 78c28bf5f6049d3f224b22eba1be0db59706a58c Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Wed, 27 Sep 2023 08:29:53 -0700 Subject: [PATCH 097/122] Support scalar fp8 conversions by packing (#2379) Support fp8 scalar conversions by packing fp8 with undef values. Also add simple unittests to cover this change. --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 10 +++-- python/test/unit/language/test_core.py | 44 ++++++++++++------- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index b1c73cd3230f..c34f64539a62 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -621,26 +621,28 @@ struct FpToFpOpConversion Location loc) const { auto srcElementType = getElementType(op.getFrom()); auto dstElementType = getElementType(op.getResult()); - int numElements = 4; + + size_t numElements = 4; if (srcElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E4M3FNUZ()) { numElements = 2; } - assert(operands.size() % numElements == 0 && - "FP8 casting only support tensors with aligned sizes"); bool isSrcFP32 = srcElementType.isF32(); bool isDstFP32 = dstElementType.isF32(); auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType, isDstFP32 ? f16_ty : dstElementType); SmallVector inVals; - for (unsigned i = 0; i < numElements; i++) { + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { inVals.push_back(operands[i][0]); } if (isSrcFP32) for (Value &v : inVals) v = convertFp32ToFp16(loc, rewriter, v); + inVals.resize(numElements, + undef(typeConverter->convertType(srcElementType))); SmallVector outVals = cvtFunc(loc, rewriter, inVals); assert(outVals.size() == inVals.size()); + outVals.resize(std::min(numElements, operands.size())); if (isDstFP32) for (Value &v : outVals) v = convertFp16ToFp32(loc, rewriter, v); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 63d5b6c3acf7..05afbf31c943 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -20,6 +20,7 @@ float_dtypes = ['float16', 'float32', 'float64'] dtypes = int_dtypes + uint_dtypes + float_dtypes dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] # TODO: enable multiple cta cluster testing. 
@@ -131,7 +132,7 @@ def check_type_supported(dtype, device): cc = torch.cuda.get_device_capability() if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") - if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"): + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") @@ -1281,23 +1282,33 @@ def serialized_add(data, Lock, SEM: tl.constexpr): # --------------- -@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ - (dtype_x, dtype_z, False) +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [ + (dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes ] + [ - ('float32', 'bfloat16', False), - ('bfloat16', 'float32', False), - ('float32', 'int32', True), - ('float32', 'int1', False), - ('int8', 'bfloat16', False), + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), ] + [ - (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] + (f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64] ] + [ - (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] -]) + (f'int{x}', f'uint{x}', True, 1024) for x in [8, 16, 32, 64] +] + (([ + (dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32"] + for size in [1024, 32] +] + [ + (dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32"] + for size in [1024, 32] +]) if torch.__version__ >= "2.1" else [])) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): # bfloat16 on cc < 80 will not be tested check_type_supported(dtype_x, device) check_type_supported(dtype_z, device) @@ -1305,10 +1316,11 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): if is_hip() and (dtype_z == "bfloat16"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') - size = 1024 # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. 
if dtype_x.startswith('bfloat'): x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) else: x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 # Triton clamps negative values to zero, while numpy wraps around @@ -1331,11 +1343,13 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr): # triton result if dtype_z.startswith('bfloat'): z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size,), dtype=torch.float, device=device) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas) # torch result - if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith('float8') or dtype_x.startswith('float8'): assert bitcast is False z_ref = x_tri.to(z_tri.dtype) torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) @@ -3080,7 +3094,7 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): err_msg = str(e) if type == "noinline": - assert err_msg is not "" + assert err_msg != "" else: ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 np.testing.assert_equal(to_numpy(rand_val_tri), ans) From 9bf9c20f300c40e7b43dd4a9d1f0d895fb458f20 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 27 Sep 2023 22:13:03 -0700 Subject: [PATCH 098/122] [DOCS] update build instructions, and add testing instrs. (#2400) - Note `wheel` as a build-time dependency. - Add tips for getting a faster build. - Add instructions for running tests. - Add flag to build with ccache. (Thanks to @ThomasRaoux for most of these instructions!) --- README.md | 39 +++++++++++++++++++++++++++++++++++++-- python/setup.py | 5 +++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c522144723f4..1b51e493d8d4 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/ git clone https://github.com/openai/triton.git; cd triton; -pip install ninja cmake; # build-time dependencies +pip install ninja cmake wheel; # build-time dependencies pip install -e python ``` @@ -53,7 +53,7 @@ cd triton; python -m venv .venv --prompt triton; source .venv/bin/activate; -pip install ninja cmake; # build-time dependencies +pip install ninja cmake wheel; # build-time dependencies pip install -e python ``` @@ -97,6 +97,41 @@ arbitrary LLVM version. LLVM_SYSPATH=$LLVM_BUILD_DIR \ pip install -e python +# Tips for building + +- Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang + and lld. lld in particular results in faster builds. + +- Set `TRITON_BUILD_WITH_CCACHE=true` to build with ccache. + +- Pass `--no-build-isolation` to `pip install` to make nop builds faster. + Without this, every invocation of `pip install` uses a different symlink to + cmake, and this forces ninja to rebuild most of the `.a` files. + +# Running tests + +There currently isn't a turnkey way to run all the Triton tests, but you can +follow the following recipe. + +```shell +# One-time setup. Note we have to reinstall local Triton because torch +# overwrites it with the public version. 
+$ pip install scipy numpy torch pytest lit && pip install -e python + +# Run Python tests using your local GPU. +$ python3 -m pytest python/test/unit + +# Move to builddir. Fill in <...> with the full path, e.g. +# `cmake.linux-x86_64-cpython-3.11`. +$ cd python/build/cmake<...> + +# Run C++ unit tests. +$ ninja test + +# Run lit tests. +$ lit test +``` + # Changelog Version 2.0 is out! New features include: diff --git a/python/setup.py b/python/setup.py index 0d7bf594aff3..5450b163fb7d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -273,6 +273,11 @@ def build_extension(self, ext): "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"] + if check_env_flag("TRITON_BUILD_WITH_CCACHE"): + cmake_args += [ + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + ] + env = os.environ.copy() cmake_dir = get_cmake_dir() subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) From b25edc139e0f10c906c3d279d503d058a8810117 Mon Sep 17 00:00:00 2001 From: Simon Boehm Date: Wed, 27 Sep 2023 22:15:17 -0700 Subject: [PATCH 099/122] [FRONTEND] fix out_path parsing in AOT compiler (#2409) `out_path.with_suffix` (penultimate line) fails if out_path is string. --- python/triton/tools/compile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 32138e8740a0..d80f15e8a1aa 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -20,7 +20,7 @@ signature is provided as a list of (optionally divisibility-hinted) types or constexpr values, e.g. -`compile.py --kernel-name kernel --signature "*f32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, @@ -51,7 +51,7 @@ args = parser.parse_args() out_name = args.out_name if args.out_name else args.kernel_name - out_path = args.out_path if args.out_path else out_name + out_path = args.out_path if args.out_path else Path(out_name) # execute python sources and extract functions wrapped in JITFunction arg_path = Path(args.path) From 9073a393e0aca430934efe27a4df7493b08a8493 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Wed, 27 Sep 2023 22:16:20 -0700 Subject: [PATCH 100/122] [GIT] .gitignore clangd and vscode index files. (#2406) Vscode or clangd can create indexing cache for symbol resolutions. Those files should be ignored by git. I'm basically cloning what is done in the LLVM repo. --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index e85433df82f7..0180cd911245 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,10 @@ docs/dialects/ docs/getting-started/tutorials !python/tutorials/*.py !python/tutorials/*.rst + +# clangd index. (".clangd" is a config file now, thus trailing slash) +.clangd/ +.cache +/compile_commands.json +.vscode +.vs From 1e093fbfff2fb3bd4406d9379f7aa62deaf74965 Mon Sep 17 00:00:00 2001 From: Yuheng XIE <146084381+yuhengxnv@users.noreply.github.com> Date: Thu, 28 Sep 2023 14:10:01 +0800 Subject: [PATCH 101/122] [OPTIMIZER] Calculate a proper divisibility for ExpandDims (#2397) Previously ExpandDims always inserts 1 as the new divisibility, which makes writing (x * stride)[:, None] far more slower than (x[:, None] * stride). 
A better divisibility can be afforded by computing the GCD of the old dims. Now the two code above are equally fast. E.g. the conv inductor in pytorch may be faster. --------- Co-authored-by: Yuheng XIE --- lib/Analysis/AxisInfo.cpp | 58 ++++++++++++++++++++++++------- test/Analysis/test-alignment.mlir | 29 ++++++++++++---- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index dd782b1876eb..14e766a1ad19 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -301,10 +301,19 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - // lhs = k * d_lhs - // rhs = p * d_rhs - // lhs * rhs = k * d_lhs * p * d_rhs = k * p * d_lhs * d_rhs - return lhs.getDivisibility(dim) * rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return lhsDivisibility * rhsDivisibility; } std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, @@ -511,8 +520,23 @@ class ExpandDimsOpAxisInfoVisitor final AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } contiguity.insert(contiguity.begin() + op.getAxis(), 1); - divisibility.insert(divisibility.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); constancy.insert(constancy.begin() + op.getAxis(), 1); return AxisInfo(contiguity, divisibility, constancy, operands[0]->getValue().getConstantValue()); @@ -756,12 +780,17 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { auto shift = rhs.getConstantValue().has_value() ? 
rhs.getConstantValue().value() : rhs.getDivisibility(dim); - auto numBits = log2Int(lhs.getDivisibility(dim)); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); auto maxBits = log2Int(highestPowOf2Divisor(0)); // Make sure the return value doesn't exceed highestPowOf2Divisor(0) if (shift + numBits > maxBits) return highestPowOf2Divisor(0); - return lhs.getDivisibility(dim) << shift; + return lhsDivisibility << shift; } int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, @@ -795,12 +824,15 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - if (rhs.getConstantValue().has_value()) - return std::max(1, lhs.getDivisibility(dim) / - (1 << rhs.getConstantValue().value())); - else - return std::max(1, lhs.getDivisibility(dim) / - (1 << rhs.getDivisibility(dim))); + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (1 << shift)); } int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 11ee22fe460c..77eae5f22570 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -184,13 +184,28 @@ tt.func @rem() { // ----- +// CHECK-LABEL: @expanddims +tt.func @expanddims() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 + %1 = arith.constant dense<2> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = + %2 = arith.muli %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [2, 2], constancy = [1, 1], constant_value = + %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + tt.return +} + +// ----- + // CHECK-LABEL: @broadcast tt.func @broadcast() { // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %0 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64 + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 1], constant_value = 64 %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64 + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 128], constant_value = 64 %2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32> tt.return } @@ -290,9 +305,9 @@ tt.func @shift() { %1 = arith.constant dense<8> : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 %2 = arith.constant dense<4> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = + // 
CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [1], constant_value = %3 = arith.shli %0, %1 : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %4 = arith.shrsi %0, %2 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %5 = arith.shli %1, %2 : tensor<128xi32> @@ -362,7 +377,7 @@ tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = %4 = arith.muli %2, %3 : tensor<128x1xi32> // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x1x!tt.ptr> @@ -386,11 +401,11 @@ tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = %16 = arith.muli %14, %15 : tensor<1x128xi32> // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> From 721bdebee1c77111e8f88e62502ab7a89ae71928 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 28 Sep 2023 10:29:08 -0700 Subject: [PATCH 102/122] [OPTIMIZATION] Fix performance for attention backward path with mma v3 (#2411) Support having chain of mma with mixed size. Serialize the different block calculation in backward attention to workaround problem with ptxas and wgmma. 
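The serialization shows up in the fused-attention tutorial below as a purely structural change: the `tl.program_id(1)` branch is removed, each program computes its dK/dV block and then its dQ block back to back, and the second grid dimension shrinks from 2 to 1. One piece of the mixed-instruction-shape support is the per-thread register remapping added in `lowerMmaToMma`; a plain-Python paraphrase of that indexing (with `undef` standing in for the undefined filler values the real lowering inserts) looks like this:

```python
def remap_mma_registers(vals, src_elems_per_thread, dst_elems_per_thread, undef=None):
    """Reorder one thread's register values from the source MMA layout to the
    destination layout, padding positions that do not exist in the source."""
    out = []
    for j in range(dst_elems_per_thread[0]):
        for i in range(dst_elems_per_thread[1]):
            if i >= src_elems_per_thread[1] or j >= src_elems_per_thread[0]:
                out.append(undef)
            else:
                out.append(vals[i + j * src_elems_per_thread[1]])
    return out


# e.g. remap_mma_registers(list(range(4)), (1, 4), (2, 4)) keeps the first row
# [0, 1, 2, 3] and pads the second row with `undef`.
```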
--- .../Dialect/TritonGPU/Transforms/Utility.h | 10 - lib/Analysis/Utility.cpp | 30 +-- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 40 +++- .../TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 20 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 2 +- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 32 --- .../Transforms/OptimizeDotOperands.cpp | 6 +- .../Transforms/RemoveLayoutConversions.cpp | 3 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 40 ---- python/tutorials/06-fused-attention.py | 190 +++++++++--------- test/TritonGPU/accelerate-matmul.mlir | 8 +- 11 files changed, 175 insertions(+), 206 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 375fdfac2356..fe9f9f8c5953 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -141,16 +141,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, ArrayRef shape); -// Implement backward and forward slice that will go through scf blocks when -// yield or scf results are in the slice. -// Note that like exisiting forward and backard slice this may add operations to -// the slice that are not actually dependent on the root because when a region -// is added to the slice in the forward slice all the operations of the region -// are added. We could implement a more accurate slice method by tracking value -// usage across scf regions. -void getBackwardSliceSCFAware(Operation *, SetVector *slices); -void getForwardSliceSCFAware(Value root, SetVector *slices); - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index dcf8a7704052..3f1b594ad23a 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -424,6 +424,23 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } +static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { + auto src = srcEncoding.dyn_cast(); + auto dst = dstEncoding.dyn_cast(); + if (!src || !dst) + return false; + auto srcInstrShape = src.getInstrShape(); + auto dstInstrShape = dst.getInstrShape(); + // when #mma = MmaEncoding + return src && dst && src.getVersionMajor() == 3 && + src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 && + dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2]; +} + +bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding()); +} + // For MMAV3 dotOperand layout matches mma operand for f16 case. 
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { @@ -432,7 +449,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, auto mmaLayout = srcLayout.cast(); auto dotOperandLayout = dstLayout.cast(); return mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && - dotOperandLayout.getParent() == mmaLayout && + isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) && srcTy.getElementType().isF16(); } @@ -452,17 +469,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { !srcTy.getElementType().isF32(); } -bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { - auto src = srcTy.getEncoding().cast(); - auto dst = dstTy.getEncoding().cast(); - auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy); - auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy); - // when #mma = MmaEncoding - return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 && - dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 && - srcElemsPerThread == dstElemsPerThread; -} - bool isSingleValue(Value value) { // Don't consider load as expensive if it is loading a scalar. if (auto tensorTy = value.getType().dyn_cast()) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 581252109ee2..d33217bd2a60 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -81,8 +81,7 @@ struct ConvertLayoutOpConversion // forwarding on mma->mma shortcut, lower distributed->distributed otherwise if (srcLayout.isa() && dstLayout.isa()) { if (isMmaToMmaShortcut(srcTy, dstTy)) { - rewriter.replaceOp(op, op.getSrc()); - return success(); + return lowerMmaToMma(op, adaptor, rewriter); } } if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) { @@ -963,6 +962,43 @@ struct ConvertLayoutOpConversion return failure(); } + // mma -> mma + LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType().cast(); + auto dstTy = op.getResult().getType().cast(); + if (triton::gpu::getTotalElemsPerThread(srcTy) == + triton::gpu::getTotalElemsPerThread(dstTy)) { + rewriter.replaceOp(op, op.getSrc()); + return success(); + } + // get source values + auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), + rewriter, srcTy); + SmallVector retVals; + SmallVector dstElementPerThread = + triton::gpu::getElemsPerThread(dstTy); + SmallVector srcElementPerThread = + triton::gpu::getElemsPerThread(srcTy); + for (unsigned j = 0; j < dstElementPerThread[0]; j++) { + for (unsigned i = 0; i < dstElementPerThread[1]; i++) { + if (i >= srcElementPerThread[1] || j >= srcElementPerThread[0]) { + retVals.push_back(undef(vals[0].getType())); + continue; + } + unsigned index = i + j * srcElementPerThread[1]; + retVals.push_back(vals[index]); + } + } + assert(retVals.size() == triton::gpu::getTotalElemsPerThread(dstTy)); + Value view = + getTypeConverter()->packLLElements(loc, retVals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } + // shared -> dot_operand if the result layout is mma Value lowerSharedToDotOperandMMA( triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp 
index 179b5a269ec2..b16aee5d8be9 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -220,7 +220,10 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements, - int startIndex, int numElements) { + int startIndex, int numElements, + Operation *insertBefore) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(insertBefore); if (!elements[0].getType().isF16()) { llvm::SmallVector mmaOut(numElements); for (int i = 0; i < numElements; ++i) @@ -351,9 +354,12 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, auto func = op->getParentOfType(); int numTMADescs = func->getAttr(kAttrNumTMALoadDescsName).cast().getInt(); + Operation *startSequence = nullptr; if (numTMADescs == 0) - rewriter.create(loc, 0); - rewriter.create(loc); + startSequence = rewriter.create(loc, 0); + Operation *fenceOp = rewriter.create(loc); + if (startSequence == nullptr) + startSequence = fenceOp; // WGMMA fp8 -> fp32 accumulates in lower precision than fp32. bool needsPartialAccumulator = isFP8(eltTypeA) && eltTypeC == triton::nvgpu::WGMMAEltType::f32 && @@ -362,7 +368,8 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { llvm::SmallVector mmaOut = - loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize); + loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize, + startSequence); llvm::SmallVector elemTypes; for (Value accEl : mmaOut) elemTypes.push_back(accEl.getType()); @@ -379,8 +386,9 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, a = aLoader.smemLoad(m, k, rewriter, loc); } else { unsigned regASize = (instrShape[0] * instrShape[2]) / 32; - llvm::SmallVector regA = loadReg( - rewriter, loc, structA, (m * numRepK + k) * regASize, regASize); + llvm::SmallVector regA = + loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize, + regASize, startSequence); auto regATy = LLVM::LLVMStructType::getLiteral( rewriter.getContext(), SmallVector(regA.size(), regA[0].getType())); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a36c30533841..82fc6522ccde 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1480,7 +1480,7 @@ struct TritonGPUInferLayoutInterface auto dotOpEnc = operandEncoding.dyn_cast(); if (!operandEncoding.isa() && !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && - dotOpEnc.getParent() == mmaRetEncoding)) { + dotOpEnc.getParent().isa())) { return emitOptionalError( location, "unexpected operand layout for MmaEncodingAttr v3"); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 3473c2123519..8894438f3552 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -145,35 +145,6 @@ class BlockedToMMA : public mlir::RewritePattern { } } - unsigned getMmaV3InstrN(tt::DotOp dotOp, unsigned currN) const { - auto type = dotOp.getResult().getType().cast(); - if (type.getEncoding().isa()) - return currN; - auto it = dotOpInstNs.find(dotOp.getOperation()); - if (it != dotOpInstNs.end()) - return it->second; - - SetVector slices; - mlir::getForwardSliceSCFAware(dotOp.getResult(), &slices); - 
mlir::getBackwardSliceSCFAware(dotOp.getOperation(), &slices); - unsigned N = currN; - SmallVector dotOps; - for (Operation *iter : slices) { - if (auto nextDotOp = dyn_cast(iter)) { - auto type = nextDotOp.getResult().getType().cast(); - auto AType = nextDotOp.getOperand(0).getType().cast(); - auto shapePerCTA = ttg::getShapePerCTA(type); - auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType); - dotOps.push_back(iter); - if (instrShape[1] < N) - N = instrShape[1]; - } - } - for (Operation *dotOp : dotOps) - dotOpInstNs[dotOp] = N; - return N; - } - static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, int opIdx) { Value arg = v; @@ -232,9 +203,6 @@ class BlockedToMMA : public mlir::RewritePattern { auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, AType); - if (versionMajor == 3) - instrShape[1] = getMmaV3InstrN(dotOp, instrShape[1]); - // operands Value a = dotOp.getA(); Value b = dotOp.getB(); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 15e71ea201aa..8e13719b7e53 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -247,8 +247,10 @@ struct MMAV3UseRegOperand : public OpRewritePattern { return failure(); auto srcEncoding = getEncoding(convertLhs.getSrc()).dyn_cast(); - if (!srcEncoding || srcEncoding.getVersionMajor() != 3 || - srcEncoding != getEncoding(dotOp.getResult())) + auto dstEncoding = + getEncoding(dotOp.getResult()).dyn_cast(); + if (!srcEncoding || srcEncoding.getVersionMajor() != 3 || !dstEncoding || + dstEncoding.getVersionMajor() != 3) return failure(); // We currently only support convert from f16 mma to f16 dot operand as the // other types require shuffling data across threads. 
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 0e49c0ce5dd7..d7d1ac98f15f 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -217,7 +217,8 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { if (convertOp.getResult() .getType() .cast() - .getEncoding() == encoding) + .getEncoding() + .isa()) return true; } auto yield = dyn_cast(op); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index cc17e4e87815..f315fe5ad9a7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -492,46 +492,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, return linear; } -void getBackwardSliceSCFAware(Operation *op, SetVector *slices) { - SmallVector queue = {op}; - while (!queue.empty()) { - Operation *currentOp = queue.back(); - queue.pop_back(); - SetVector temp; - auto filter = [slices](Operation *sliceOp) { - return slices->count(sliceOp) == 0; - }; - mlir::getBackwardSlice(currentOp, &temp, filter); - for (Operation *sliceOp : temp) { - if (auto forOp = dyn_cast(sliceOp)) { - queue.push_back(forOp.getBody()->getTerminator()); - } - } - slices->insert(temp.begin(), temp.end()); - } -} - -void getForwardSliceSCFAware(Value root, SetVector *slices) { - SmallVector queue = {root}; - while (!queue.empty()) { - Value currentValue = queue.back(); - queue.pop_back(); - SetVector temp; - auto filter = [slices](Operation *sliceOp) { - return slices->count(sliceOp) == 0; - }; - mlir::getForwardSlice(currentValue, &temp, filter); - for (Operation *sliceOp : temp) { - if (auto yieldOp = dyn_cast(sliceOp)) { - auto forOp = yieldOp->getParentOfType(); - if (forOp) - queue.append(forOp->getResults().begin(), forOp->getResults().end()); - } - } - slices->insert(temp.begin(), temp.end()); - } -} - namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index eac1330c4077..fe420bc63ce6 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -323,106 +323,102 @@ def _attn_bwd( # load scales offs_k = tl.arange(0, BLOCK_DMODEL) - if (tl.program_id(1) == 0): - - # THIS BLOCK DOES DK/DV/DR: - - start_n = pid * BLOCK_N1 - start_m = start_n - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - offs_n = start_n + tl.arange(0, BLOCK_N1) - - dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - - num_steps = BLOCK_N1 // MASK_BLOCK_M1 - - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=True, - ) - - start_m += num_steps * MASK_BLOCK_M1 - num_steps = (N_CTX - start_m) // BLOCK_M1 - - # Compute dK and dV for non-masked blocks. 
- dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=False, - ) - - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) - - # Write back dK. - dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) + # THIS BLOCK DOES DK/DV/DR: - else: - - # THIS BLOCK DOES DQ: - start_m = pid * BLOCK_M2 - end_n = start_m + BLOCK_M2 + start_n = pid * BLOCK_N1 + start_m = start_n - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - offs_m = start_m + tl.arange(0, BLOCK_M2) + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True, + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False, + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) - m = tl.load(M + offs_m) - m = m[:, None] - - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _attn_bwd_dq, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq( - dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, - MASK=True, - ) - end_n -= num_steps * MASK_BLOCK_N2 - # stage 2 - num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq( - dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * BLOCK_N2, num_steps, - MASK=False, - ) - # Write back dQ. 
- dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - dq *= LN2 - tl.store(dq_ptrs, dq) + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True, + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False, + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) empty = torch.empty(128, device="cuda") @@ -491,7 +487,7 @@ def backward(ctx, do): BATCH, N_HEAD, N_CTX, BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) _attn_bwd[grid]( q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, M, delta, diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index fbb83a1aec39..a56c67f15f95 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,6 +1,8 @@ // RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s // CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> +// CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +// CHECK: #[[MMA2:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 32, 16]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> @@ -22,7 +24,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> // CHECK: scf.for // CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> - // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 
iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { %172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> %178 = triton_gpu.convert_layout %172 : (tensor<128x16xf16, #blocked>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> @@ -30,8 +32,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c scf.yield %180 : tensor<128x64xf16, #blocked1> } // CHECK: scf.for - // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA]]> - // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> + // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { %166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> %172 = triton_gpu.convert_layout %166 : (tensor<128x32xf16, #blocked2>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> From 99af23f6f4368c62fa23e7acad925f65d85a734e Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Thu, 28 Sep 2023 11:48:38 -0600 Subject: [PATCH 103/122] [TUTORIALS] Remove unneeded quantiles parameter (#2408) The fix is to remove the quantiles parameter in both the triton and torch calls for the benchmark. --- python/tutorials/01-vector-add.py | 5 ++--- python/tutorials/02-fused-softmax.py | 7 +++---- python/tutorials/03-matrix-multiplication.py | 5 ++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 3463ddf1ced1..7e663459e8be 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -124,11 +124,10 @@ def add(x: torch.Tensor, y: torch.Tensor): def benchmark(size, provider): x = torch.rand(size, device='cuda', dtype=torch.float32) y = torch.rand(size, device='cuda', dtype=torch.float32) - quantiles = [0.5, 0.2, 0.8] if provider == 'torch': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y)) gbps = lambda ms: 12 * size / ms * 1e-6 return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 13383cc1c783..362e317df25d 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -180,13 +180,12 @@ def softmax(x): ) def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=torch.float32) - quantiles = [0.5, 0.2, 0.8] if provider == 'torch-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: 
softmax(x)) if provider == 'torch-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x)) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 8bcae2007abd..a087ba39e77f 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -338,11 +338,10 @@ def matmul(a, b, activation=""): def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b)) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b)) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) From d4fae901694d6977bc5d4fb1d627f1098d1f3ef7 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 28 Sep 2023 13:59:15 -0700 Subject: [PATCH 104/122] [BACKEND][NFC] Simplify conversion to TritonGPU (#2416) Remove ad hoc patterns. This will help LLVM transition. --- .../TritonToTritonGPUPass.cpp | 252 ++---------------- 1 file changed, 27 insertions(+), 225 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 5b980caa23be..b4f091d7bb02 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -29,17 +29,19 @@ static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { op->setAttr(attr.getName(), attr.getValue()); } -template class GenericOpPattern : public OpConversionPattern { -public: +template struct GenericOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()), - adaptor.getAttributes()); + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + return success(); } }; @@ -88,22 +90,6 @@ class ArithConstantPattern : public OpConversionPattern { } }; -class ConvertArithOp : public ConversionPattern { -public: - ConvertArithOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, - context) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Dialect *dialect = op->getDialect(); - if (dialect->getTypeID() != mlir::TypeID::get()) - return failure(); - return success(); - } -}; - void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns, TritonGPUConversionTarget &target) { 
@@ -199,22 +185,6 @@ void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, // // Triton patterns // -// TODO: Do we need to put them in anonymous namespace? -struct TritonMakeRangePattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = getTypeConverter()->convertType(op.getType()); - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, retType, adaptor.getStart(), adaptor.getEnd()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonExpandDimsPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -435,85 +405,6 @@ struct TritonTransPattern : public OpConversionPattern { } }; -struct TritonLoadPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(), - adaptor.getBoundaryCheckAttr(), adaptor.getPaddingAttr(), - adaptor.getCache(), adaptor.getEvict(), - adaptor.getIsVolatile()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonStorePattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getPtr(), adaptor.getValue(), - adaptor.getMask(), adaptor.getCache(), - adaptor.getEvict()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonAtomicCASPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal(), - op.getSem()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonAtomicRMWPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getAtomicRmwOp(), adaptor.getPtr(), - adaptor.getVal(), adaptor.getMask(), op.getSem()), - adaptor.getAttributes()); - return success(); - } -}; - -template -struct TritonGenericPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonBroadcastPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -555,20 +446,6 @@ struct TritonReducePattern : public OpConversionPattern { } }; -struct TritonReduceReturnPattern - : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getResult()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonScanPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -587,66 +464,6 @@ struct TritonScanPattern : public OpConversionPattern { } }; -struct TritonScanReturnPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ScanReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getResult()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonExternElementwisePattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ExternElementwiseOp op, - typename triton::ExternElementwiseOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, retType, adaptor.getOperands(), op.getLibnameAttr(), - op.getLibpathAttr(), op.getSymbolAttr(), - op.getPureAttr()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonPrintPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::PrintOp op, typename triton::PrintOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, op.getPrefixAttr(), adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonAssertPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AssertOp op, - typename triton::AssertOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), op.getMessageAttr(), - op.getFileAttr(), op.getFuncAttr(), op.getLineAttr()), - adaptor.getAttributes()); - return success(); - } -}; - class TritonFuncOpPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -698,22 +515,23 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, MLIRContext *context = patterns.getContext(); patterns.insert< // TODO: view should have custom pattern that views the // layout - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, TritonBroadcastPattern, - TritonGenericPattern, TritonCatPattern, - TritonGenericPattern, TritonReducePattern, - TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern, - TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, - TritonDotPattern, TritonLoadPattern, TritonStorePattern, - TritonExternElementwisePattern, TritonPrintPattern, TritonAssertPattern, - TritonAtomicRMWPattern, TritonFuncOpPattern, TritonReturnOpPattern, - TritonCallOpPattern>(typeConverter, context); + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + 
GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, GenericOpPattern, + TritonReducePattern, GenericOpPattern, + TritonScanPattern, GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, + context); } // @@ -763,22 +581,6 @@ struct SCFForPattern : public OpConversionPattern { } }; -struct SCFYieldPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - // rewriter.create(op.getLoc(), adaptor.getOperands()); - // op.erase(); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - // This is borrowed from ConvertFIfOpTypes in // SCF/Transforms/StructuralTypeConversions.cpp class SCFIfPattern : public OpConversionPattern { @@ -866,8 +668,8 @@ class SCFConditionPattern : public OpConversionPattern { void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add(typeConverter, context); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); } // CF From 90bef57acfda717c52c688e755e727e30c881c97 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 28 Sep 2023 22:45:28 -0700 Subject: [PATCH 105/122] [BACKEND] turn on MMA V3 by default on Hopper (#2414) --- .github/workflows/integration-tests.yml | 20 +++++++++---------- include/triton/Tools/Sys/GetEnv.hpp | 2 +- lib/Analysis/Utility.cpp | 2 +- .../Transforms/FenceInsertion.cpp | 3 +-- python/test/unit/hopper/test_gemm.py | 5 ++--- .../test_persistent_warp_specialized_gemm.py | 1 - python/test/unit/language/test_core.py | 10 ++++------ python/triton/ops/flash_attention.py | 7 ++----- test/Conversion/tritongpu_to_llvm_hopper.mlir | 2 +- test/TritonGPU/accelerate-matmul.mlir | 2 +- test/TritonGPU/fence-inserstion.mlir | 2 +- 11 files changed, 23 insertions(+), 33 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 3234a1094cd0..b4f4495cf4bf 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -53,7 +53,6 @@ jobs: run: | echo "BACKEND=CUDA" >> "${GITHUB_ENV}" echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" - echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}" echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" - name: Clear cache @@ -90,14 +89,13 @@ jobs: fi lit -v "${LIT_TEST_DIR}" - - name: Enable MMAV3 and TMA + - name: Enable TMA if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}} run: | echo "ENABLE_TMA=1" >> "${GITHUB_ENV}" - echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}" - - name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} + - name: Run python tests on CUDA with ENABLE_TMA=1 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | cd python/test/unit python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py @@ -109,8 
+107,8 @@ jobs: #run hopper/test_flashattention.py to avoid out of gpu memory python3 -m pytest hopper/test_flashattention.py - - name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} + - name: Run python tests on CUDA with ENABLE_TMA=0 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} run: | cd python/test/unit python3 -m pytest -n 8 --ignore=runtime --ignore=hopper --ignore=operators --ignore=language/test_line_info.py @@ -131,14 +129,14 @@ jobs: cd python/test/unit python3 -m pytest -vs operators/test_flash_attention.py - - name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} + - name: Run partial tests on CUDA with ENABLE_TMA=1 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | cd python/test/unit python3 -m pytest -n 8 operators - - name: Run partial tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} + - name: Run partial tests on CUDA with ENABLE_TMA=0 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} run: | cd python/test/unit python3 -m pytest -n 8 operators diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index bf682af946b1..d1ff8ab83654 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -30,7 +30,7 @@ namespace triton { const std::set ENV_VARS = { - "ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", "AMDGCN_ENABLE_DUMP"}; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 3f1b594ad23a..512fbc6b92c2 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -381,7 +381,7 @@ bool supportMMA(triton::DotOp op, int version) { auto aElemTy = op.getA().getType().cast().getElementType(); auto bElemTy = op.getB().getType().cast().getElementType(); if (version == 3) { - if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; auto retType = op.getResult().getType().cast(); auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index 3d839874607d..25a43529d68c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -40,8 +40,7 @@ struct FenceInsertionPass // Only insert fences for compute capability 9.0 if (computeCapability < 90) return; - // ENABLE_MMA_V3 - if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { diff --git a/python/test/unit/hopper/test_gemm.py b/python/test/unit/hopper/test_gemm.py index af236d0de3d8..a3e3f80b917a 100644 --- a/python/test/unit/hopper/test_gemm.py +++ b/python/test/unit/hopper/test_gemm.py @@ -331,7 +331,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, ]: pytest.skip('shapePerCTA[1] < 16 not supported') - # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ 
'16-32-64-4-1-256-256-256-False', '16-32-64-4-2-256-256-256-False', @@ -444,7 +443,7 @@ def grid(META): atol=1e-3, check_dtype=False) - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: + disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() + if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: ptx = pgm.asm['ptx'] assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index bd1e70ec58a9..32c04c33bc31 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -818,7 +818,6 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR ]: pytest.skip('shapePerCTA[1] < 16 not supported') - # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ '16-32-64-4-1-256-256-256-False', '16-32-64-4-2-256-256-256-False', diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 05afbf31c943..8317b1670ee3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2433,11 +2433,10 @@ def kernel(X, stride_xm, stride_xk, red_code = ptx[start:end] assert len(red_code) > 0 import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. # TODO: we should eliminate these unused functions in ptx code. - if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]): + if not (capability[0] >= 9): assert "shared" not in red_code assert "bar.sync" not in red_code # torch result @@ -2540,13 +2539,12 @@ def kernel(Z, X, Y, if is_hip(): return assert "tt.dot" in h.asm['ttir'] - # with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y + # when using MMAv3, we will not pipeline the load op for Y # as the loaded value is in rowmajor. But MMAv3 requires it's second # operand is in colmajor because transpose is not supported for MMAv3 # with float32 input. 
import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - if enable_mmav3 in ["on", "true", "1"]: + if capability[0] >= 9: assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir'] else: assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir'] diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 752c7d2f822d..651c162e602b 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -335,11 +335,8 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): @staticmethod def backward(ctx, do): - import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - MMA_V3 = False - if enable_mmav3 in ["on", "true", "1"]: - MMA_V3 = True + capability = torch.cuda.get_device_capability() + MMA_V3 = capability[0] >= 9 BLOCK = 128 q, k, v, o, L = ctx.saved_tensors sequence_parallel = ctx.sequence_parallel diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 33db273c8ed5..21a7e29b0722 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index a56c67f15f95..09fa73cd989c 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s // CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> // CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index c5ef88dbc68b..ff4a1fe1e8f2 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s +// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 
0], instrShape = [16, 128, 16]}> From f2f5f1d45710f694fb33fb0d165b0e55a07dd3bf Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 29 Sep 2023 17:24:30 -0400 Subject: [PATCH 106/122] [TUTORIALS] Add missing docstrings (#2420) Depend on https://github.com/openai/triton/pull/2419 to fix the documentation workflow --- python/tutorials/11-grouped-gemm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py index 034e4e217abe..e27acebd1627 100644 --- a/python/tutorials/11-grouped-gemm.py +++ b/python/tutorials/11-grouped-gemm.py @@ -1,3 +1,11 @@ + +""" +Group GEMM +============================ +This group gemm kernel launches a fixed number of CTA to compute a group +of gemms. The scheduling is static and we do it on device. +""" + # Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining @@ -24,9 +32,6 @@ import triton import triton.language as tl -# This group gemm kernel launches a fixed number of CTA to compute a group -# of gemms. The scheduling is static and we do it on device - @triton.autotune( configs=[ From e284112818146a2f99cdc2e52af98ba4b329e586 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 29 Sep 2023 17:24:50 -0400 Subject: [PATCH 107/122] Revert "[TUTORIALS] Remove unneeded quantiles parameter (#2408)" (#2419) This reverts commit 99af23f6f4368c62fa23e7acad925f65d85a734e. `quantiles` shouldn't be the problem. The documentation workflow failed because of other issues. --- python/tutorials/01-vector-add.py | 5 +++-- python/tutorials/02-fused-softmax.py | 7 ++++--- python/tutorials/03-matrix-multiplication.py | 5 +++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 7e663459e8be..3463ddf1ced1 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -124,10 +124,11 @@ def add(x: torch.Tensor, y: torch.Tensor): def benchmark(size, provider): x = torch.rand(size, device='cuda', dtype=torch.float32) y = torch.rand(size, device='cuda', dtype=torch.float32) + quantiles = [0.5, 0.2, 0.8] if provider == 'torch': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) gbps = lambda ms: 12 * size / ms * 1e-6 return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 362e317df25d..13383cc1c783 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -180,12 +180,13 @@ def softmax(x): ) def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=torch.float32) + quantiles = [0.5, 0.2, 0.8] if provider == 'torch-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles) if provider == 'torch-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x)) + ms, 
min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index a087ba39e77f..8bcae2007abd 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -338,10 +338,11 @@ def matmul(a, b, activation=""): def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) From e0edb70f78a3702f727bbf1e9c2977d8f90bf530 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Fri, 29 Sep 2023 17:29:41 -0700 Subject: [PATCH 108/122] [BACKEND] support of Fp8E4M3Nv to Bf16 conversion (#2415) --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 24 +++++++++++++++++++ python/test/unit/language/test_core.py | 4 ++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index c34f64539a62..f3f39db0ca2a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -182,6 +182,28 @@ const std::string Fp16_to_Fp8E4M3Nv = "{ \n" "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n" "}"; +// Fp8E4M3 (x2) -> Fp16 (x2) (packed) +const std::string Fp8E4M3Nv_to_Bf16 = + "{ \n" + ".reg .f16 a<2>; \n" + ".reg .bf16 b<2>; \n" + "cvt.rn.f16x2.e4m3x2 {a0, a1}, $1; \n" + "cvt.bf16.f16 b0, a0; \n" + "cvt.bf16.f16 b1, a1; \n" + "mov.b32 $0, {b0, b1}; \n" + "}"; + +// Bf16 (x2) -> Fp8E4M3 (x2) (packed) +const std::string Bf16_to_Fp8E4M3Nv = + "{ \n" + ".reg .bf16 a<2>; \n" + ".reg .f32 b<2>; \n" + "mov.b32 {a0, a1}, $1; \n" + "cvt.f32.bf16 b0, a0; \n" + "cvt.f32.bf16 b1, a1; \n" + "cvt.rn.satfinite.e4m3x2.f32 $0, b0, b1; \n" + "}"; + /* ----- Packed integer to BF16 ------ */ const std::string S8_to_Bf16 = "{ \n" @@ -582,8 +604,10 @@ struct FpToFpOpConversion {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2}, // F8 -> BF16 {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, + {{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16}, // BF16 -> F8 {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2}, + {{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv}, }; int inVecWidthBits = 32; int outVecWidthBits = 32; diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8317b1670ee3..bb857fbcbf5e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1299,12 +1299,12 @@ def serialized_add(data, Lock, SEM: tl.constexpr): ] + (([ (dtype_x, dtype_z, False, size) for dtype_x in torch_float8_dtypes - for dtype_z in ["float16", "float32"] + for dtype_z in ["float16", "float32", "bfloat16"] for size in [1024, 32] ] + [ (dtype_x, dtype_z, False, size) for dtype_z in torch_float8_dtypes - for dtype_x in ["float16", "float32"] + for dtype_x in 
["float16", "float32", "bfloat16"] for size in [1024, 32] ]) if torch.__version__ >= "2.1" else [])) @pytest.mark.parametrize("num_ctas", num_ctas_list) From ee013d8978ab03a89ebab0f72d165b2852ba9a82 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Fri, 29 Sep 2023 19:36:00 -0700 Subject: [PATCH 109/122] Fix PTX issues in bf16 / fp8_e4m3 conversion (#2421) Fix bugs in https://github.com/openai/triton/pull/2415. cc @htyu Previously corresponding tests failed on H100 with latest torch version. It passed CI because CI doesn't use latest torch, so the tests were skipped. --- lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index f3f39db0ca2a..44e2a2b71700 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -185,9 +185,11 @@ const std::string Fp16_to_Fp8E4M3Nv = "{ \n" // Fp8E4M3 (x2) -> Fp16 (x2) (packed) const std::string Fp8E4M3Nv_to_Bf16 = "{ \n" + ".reg .b32 a; \n" ".reg .f16 a<2>; \n" - ".reg .bf16 b<2>; \n" - "cvt.rn.f16x2.e4m3x2 {a0, a1}, $1; \n" + ".reg .b16 b<2>; \n" + "cvt.rn.f16x2.e4m3x2 a, $1; \n" + "mov.b32 {a0, a1}, a; \n" "cvt.bf16.f16 b0, a0; \n" "cvt.bf16.f16 b1, a1; \n" "mov.b32 $0, {b0, b1}; \n" @@ -196,7 +198,7 @@ const std::string Fp8E4M3Nv_to_Bf16 = // Bf16 (x2) -> Fp8E4M3 (x2) (packed) const std::string Bf16_to_Fp8E4M3Nv = "{ \n" - ".reg .bf16 a<2>; \n" + ".reg .b16 a<2>; \n" ".reg .f32 b<2>; \n" "mov.b32 {a0, a1}, $1; \n" "cvt.f32.bf16 b0, a0; \n" From 533efd0cacb6fb6fd530533c21c7a26371eee165 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 29 Sep 2023 23:33:28 -0700 Subject: [PATCH 110/122] [FRONTEND][BACKEND] changed float8e4b15 clipping semantics from +-1.875 to +-1.75 (#2422) clipping float8e4b15 to +-1.875 is a bad idea because these are represented as 0x7f and 0xff, which are +- nan on H100 for float8e4nv. We lose two values but this will make compatibility with float8e4nv way less painful. 
(it will just be a matter of adjusting the bias) --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 4 ++-- python/test/unit/language/test_core.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 44e2a2b71700..f33bee2b3ed4 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -109,8 +109,8 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) { ".reg .b16 c<4>; \n" ".reg .b16 max_val_f16; \n" ".reg .b32 max_val_f16x2; \n" - "mov.b16 max_val_f16, 0x3F80; \n" - "mov.b32 max_val_f16x2, 0x3F803F80; \n" + "mov.b16 max_val_f16, 0x3F00; \n" + "mov.b32 max_val_f16x2, 0x3F003F00; \n" "and.b32 a0, $1, 0x7fff7fff; \n" "and.b32 a1, $2, 0x7fff7fff; \n"; if (has_minx2) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index bb857fbcbf5e..c7c76c6a904b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1526,22 +1526,26 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): # initialize array containing all possible f8 values except NaN ref_fp8 = np.array(range(-128, 128), dtype=np.int8) - is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1) + is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask) - ref_fp8[is_nan] = 0 - ref_fp8[is_subnormal] = 0 tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda() + # check that non-subnormal fp8 are correctly converted to fp16 tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda") copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) - ref_fp8 = torch.from_numpy(ref_fp8).cuda() ref_fp16 = convert_float_to_float32(ref_fp8, in_dtype) assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal]) - + # check that values are properly converted back to float8 ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8) copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) - assert torch.all(tri_fp8 == ref_fp8) + if in_dtype == tl.float8e4b15: + assert torch.all(tri_fp8[:127] == ref_fp8[:127]) + assert torch.all(tri_fp8[128:255] == ref_fp8[128:255]) + assert ref_fp8[126] == ref_fp8[127] # -1.875 saturates to -1.75 + assert ref_fp8[254] == ref_fp8[255] # 1.875 saturates to 1.75 + else: + assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal]) # --------------- # test reduce From c4f3afc020ddd27c6a06e2583807c2bd5cff6429 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 29 Sep 2023 23:48:08 -0700 Subject: [PATCH 111/122] [CI] disable pypy wheel (#2423) emitting warnings from C++ code requires "#include pybind11/exec.h" which is not compatible with pypy. I think using the python interpreter form C++ is a bad idea in general... 
but we probably don't care much about pypy wheels anyway --- .github/workflows/wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index b841e8942811..7408d3348a62 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -46,8 +46,8 @@ jobs: export CIBW_MANYLINUX_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" #export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" export CIBW_BEFORE_BUILD="pip install cmake;" - export CIBW_SKIP="{cp,pp}{35,36}-*" - export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" + export CIBW_SKIP="{cp}{35,36}-*" + export CIBW_BUILD="{cp}3*-manylinux_x86_64" python3 -m cibuildwheel python --output-dir wheelhouse - name: Install Azure CLI From 98039658d474d95e714a302783aa2cea19eb017c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 30 Sep 2023 00:38:06 -0700 Subject: [PATCH 112/122] [CI] disable pypy wheel (continued) (#2424) there's a typo in the previous commit --- .github/workflows/wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7408d3348a62..98f05e7202bb 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -46,8 +46,8 @@ jobs: export CIBW_MANYLINUX_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" #export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" export CIBW_BEFORE_BUILD="pip install cmake;" - export CIBW_SKIP="{cp}{35,36}-*" - export CIBW_BUILD="{cp}3*-manylinux_x86_64" + export CIBW_SKIP="cp{35,36}-*" + export CIBW_BUILD="cp3*-manylinux_x86_64" python3 -m cibuildwheel python --output-dir wheelhouse - name: Install Azure CLI From 97e35b677bf991765a612f7e0995accdabe91a43 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Sat, 30 Sep 2023 19:53:43 +0200 Subject: [PATCH 113/122] [BACKEND] fix division by 0 pathway (#2412) It was possible for multiDimWarpId[1] to be 0 which then gets translated into a `urem 0, 0` and results in an unreachable when going through llvm, an empty kernel, and nans. This PR uses ceiling to clamp the result to be >=1. chsigg is working on a fix to lower the unreachable in llvm to a trap (https://github.com/llvm/llvm-project/pull/67478). 
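A minimal standalone sketch of the clamp (illustration only: `ceilDiv` below stands in for the `ceil` helper used in the hunk, and the shape values are made up):

    #include <cstdio>

    // Integer ceiling division: keeps the modulus operand >= 1 even when the
    // per-CTA shape is smaller than one MMA instruction tile.
    static unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

    int main() {
      unsigned shapePerCTA = 8, instrShape = 16, warpId = 3;
      // Plain division: 8 / 16 == 0, so warpId % 0 becomes the undefined `urem 0, 0`.
      // Ceiling division: ceilDiv(8, 16) == 1, so the remainder is well defined.
      printf("%u\n", warpId % ceilDiv(shapePerCTA, instrShape)); // prints 0
      return 0;
    }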
--- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 6 ++++-- test/Conversion/divide-by-0.mlir | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 test/Conversion/divide-by-0.mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d33217bd2a60..404ba2be61d4 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -177,9 +177,11 @@ struct ConvertLayoutOpConversion Value _16 = i32_val(16); if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { multiDimWarpId[0] = - urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0])); + urem(multiDimWarpId[0], + i32_val(ceil(shapePerCTA[0], instrShape[0]))); multiDimWarpId[1] = - urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1])); + urem(multiDimWarpId[1], + i32_val(ceil(shapePerCTA[1], instrShape[1]))); Value mmaGrpId = udiv(laneId, _4); Value mmaGrpIdP8 = add(mmaGrpId, _8); diff --git a/test/Conversion/divide-by-0.mlir b/test/Conversion/divide-by-0.mlir new file mode 100644 index 000000000000..8eca5a34887a --- /dev/null +++ b/test/Conversion/divide-by-0.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --cse | FileCheck %s + +// CHECK-LABEL: dont_divide_0 +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NOT: llvm.urem %{{.*}}, %[[C0]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @dont_divide_0() attributes {noinline = false} { + %zero = arith.constant dense<0.000000e+00> : tensor<16x1xf32, #mma> + %cvt = triton_gpu.convert_layout %zero : (tensor<16x1xf32, #mma>) -> tensor<16x1xf32, #blocked> + tt.return + } +} From a0025cfc44c841e0c37fa3f3709ea96c109601e8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 1 Oct 2023 16:07:50 -0700 Subject: [PATCH 114/122] [FRONTEND] add missing implicit constexpr conversion in `dot` (#2427) --- python/triton/language/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 496aa42000c9..b33aa2e10c59 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1002,6 +1002,7 @@ def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out """ allow_tf32 = _constexpr_to_value(allow_tf32) out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder) From ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 2 Oct 2023 00:43:05 -0400 Subject: [PATCH 115/122] [BACKEND] Fine-tune SharedMemoryObject definition and fix related problems (#2428) --- .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 7 ++++--- lib/Conversion/TritonGPUToLLVM/Utility.h | 17 ++++++++++------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git 
a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 667f174f2270..2763a76c65c5 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -462,11 +462,12 @@ class ConvertTritonGPUOpToLLVMPatternBase { unsigned inVec = srcSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); + SmallVector offsetVals = {i32_val(0), i32_val(0)}; assert(outElems == dstIndices.size()); - DenseMap sharedPtrs = getSwizzledSharedPtrs( - loc, outVec, dstTy, srcSharedLayout, srcElemTy, smemObj, rewriter, - smemObj.offsets, smemObj.strides); + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, srcElemTy, + smemObj, rewriter, offsetVals, smemObj.strides); assert(outElems % minVec == 0 && "Unexpected number of elements"); unsigned numVecs = outElems / minVec; auto wordTy = vec_ty(elemTy, minVec); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 43faa333e05a..08357af6af20 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -232,21 +232,24 @@ SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, Location loc, ConversionPatternRewriter &rewriter); struct SharedMemoryObject { - Value base; // i32 ptr. The start address of the shared memory object. - // We need to store strides as Values but not integers because the + Value base; // i32 ptr. The start address of the shared memory object after + // the initial allocation or the last slicing operation. + // We need to store strides as Values, not integers, because the // extract_slice instruction can take a slice at arbitrary offsets. - // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is - // 32, we need to let the instruction that uses $a to be aware of that. + // Take $a[16:32, 16:32] as an example; though we know the stride of $a[0] is + // 32, we need to let the instruction that uses $a be aware of that. // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If // we store strides into an attribute array of integers, the information // cannot pass through block argument assignment because attributes are - // associated with operations but not Values. + // associated with operations, not Values. // TODO(Keren): We may need to figure out a way to store strides as integers // if we want to support more optimizations. SmallVector strides; // i32 int. The strides of the shared memory object. - SmallVector offsets; // i32 int. The offsets of the shared memory - // objects from the originally allocated object. + SmallVector offsets; // i32 int. + // Offsets are applied at the last slicing operation. + // We can use offsets to recover the previous base. + // The offsets are zero at the initial allocation. 
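+  // For example (hypothetical values): slicing $a[16:32, 16:32] out of a
+  // 32x32 allocation leaves offsets = {16, 16}; the previous base can then
+  // be recovered, roughly, by stepping base back offsets[d] * strides[d]
+  // elements along each dimension d.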
SharedMemoryObject(Value base, ArrayRef strides, ArrayRef offsets) From 3a6dc5ad8d1298dd93883282bc88c06d9f6a01bc Mon Sep 17 00:00:00 2001 From: Jason Furmanek Date: Thu, 5 Oct 2023 21:05:49 +0000 Subject: [PATCH 116/122] resolve some merge conflicts fix more conflits Resolve merge conflicts Some more build and conflict fixes Resolve conflicts for 06-fused-attension.py resolve merge conflicts for the tutorial group gemm example Fixes for some LIT tests resolve remaining conflicts in tests Fix empty kernel set capability 0 --- .gitignore | 6 - CMakeLists.txt | 1 + README.md | 28 -- bin/CMakeLists.txt | 1 + bin/triton-translate.cpp | 4 +- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 24 +- lib/Analysis/Allocation.cpp | 15 - lib/Analysis/Utility.cpp | 20 -- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 21 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 - .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 7 +- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 209 +------------ .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 43 +-- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 73 ++--- lib/Conversion/TritonGPUToLLVM/Utility.h | 5 - lib/Dialect/TritonGPU/IR/Dialect.cpp | 29 +- .../Transforms/AccelerateAMDMatmul.cpp | 3 +- .../Transforms/RemoveLayoutConversions.cpp | 4 - lib/Target/CMakeLists.txt | 1 + lib/Target/LLVMIR/LLVMIRTranslation.cpp | 16 +- python/src/triton.cc | 6 +- python/test/unit/language/test_core.py | 83 ------ .../unit/operators/test_flash_attention.py | 18 +- python/triton/compiler/compiler.py | 185 ++---------- python/triton/compiler/make_launcher.py | 4 - python/triton/language/semantic.py | 39 +-- python/triton/runtime/jit.py | 92 +----- python/tutorials/06-fused-attention.py | 281 ++++-------------- python/tutorials/11-grouped-gemm.py | 50 ---- test/Conversion/minimize_alloc.mlir | 4 +- test/Conversion/tritongpu_to_llvm.mlir | 74 ++--- test/TritonGPU/stream-pipeline.mlir | 16 +- 32 files changed, 189 insertions(+), 1176 deletions(-) diff --git a/.gitignore b/.gitignore index 47fe5db52678..0180cd911245 100644 --- a/.gitignore +++ b/.gitignore @@ -30,11 +30,6 @@ cuobjdump nvdisasm ptxas -<<<<<<< HEAD -# HIP -log* -python/triton/third_party/cuda/bin/ptxas -======= # Docs docs/_build/ docs/python-api/generated/ @@ -49,4 +44,3 @@ docs/getting-started/tutorials /compile_commands.json .vscode .vs ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 diff --git a/CMakeLists.txt b/CMakeLists.txt index d4c2e4a064ad..3ac51329bb03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,6 +248,7 @@ if(TRITON_BUILD_PYTHON_MODULE) TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX + TritonHSACO ${dialect_libs} ${conversion_libs} diff --git a/README.md b/README.md index 7a0df5bc8d9e..52bb5b41e443 100644 --- a/README.md +++ b/README.md @@ -60,11 +60,6 @@ lit -v test ``` git clone https://github.com/openai/triton.git; -<<<<<<< HEAD -cd triton/python; -pip install ninja cmake; # build-time dependencies -pip install -e . -======= cd triton; pip install ninja cmake wheel; # build-time dependencies @@ -82,7 +77,6 @@ source .venv/bin/activate; pip install ninja cmake wheel; # build-time dependencies pip install -e python ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ``` # Building with a custom LLVM @@ -105,27 +99,6 @@ arbitrary LLVM version. modifications to LLVM. 3. [Build LLVM](https://llvm.org/docs/CMake.html). For example, you might run -<<<<<<< HEAD - - $ cd $HOME/llvm-project # your clone of LLVM. 
- $ mkdir build - $ cd build - $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" - $ ninja - -4. Grab a snack, this will take a while. - -5. Build Triton as above, but set the following environment variables. - - # Modify as appropriate to point to your LLVM build. - $ export LLVM_BUILD_DIR=$HOME/llvm-project/build - - $ cd /python - $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ - LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ - LLVM_SYSPATH=$LLVM_BUILD_DIR \ - pip install -e . -======= $ cd $HOME/llvm-project # your clone of LLVM. $ mkdir build @@ -180,7 +153,6 @@ $ ninja test # Run lit tests. $ lit test ``` ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # Changelog diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 9da8e5628667..075d71282715 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -53,6 +53,7 @@ llvm_update_compile_flags(triton-translate) TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX + TritonHSACO ${dialect_libs} ${conversion_libs} # tests diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp index 1b81425ce209..ee1da6a20fbc 100644 --- a/bin/triton-translate.cpp +++ b/bin/triton-translate.cpp @@ -15,6 +15,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" #include "llvm/IR/LLVMContext.h" @@ -142,14 +143,11 @@ LogicalResult tritonTranslateMain(int argc, char **argv, } else if (targetKind == "ptx") { llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(), ptxVersion.getValue()); -<<<<<<< HEAD } else if (targetKind == "hsaco") { auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO( *llvmir, GCNArch.getValue(), GCNTriple.getValue(), GCNFeatures.getValue()); llvm::outs() << hsaco; -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } else { llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n"; return failure(); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 8e24c915a9ab..1e4f9d9f0760 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -113,7 +113,17 @@ compared to 1*64 when the hasLeadingOffset is false. "ArrayRef":$order, "CTALayoutAttr":$CTALayout, "unsigned":$typeWidthInBit), [{ -<<<<<<< HEAD + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ #ifdef USE_ROCM // ---- begin GFX908/GFX90A ---- @@ -156,18 +166,6 @@ compared to 1*64 when the hasLeadingOffset is false. 
} } #endif -======= - bool needTrans = false; // default value - return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); - }]>, - - AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, - "ArrayRef":$shape, - "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, - "unsigned":$typeWidthInBit, - "bool":$needTrans), [{ ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 auto mmaEnc = dotOpEnc.getParent().dyn_cast(); if(!mmaEnc) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index bb4efa69dbdd..1ed0c7658de7 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -18,11 +18,8 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; -<<<<<<< HEAD using ::mlir::triton::gpu::MfmaEncodingAttr; -======= using ::mlir::triton::gpu::getUniqueContigPerThread; ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -79,7 +76,6 @@ SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { } } -<<<<<<< HEAD #ifdef USE_ROCM if (srcLayout.isa() && srcLayout.dyn_cast().getIsTransposed() && @@ -88,18 +84,7 @@ SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { return {}; #endif - assert(srcLayout && dstLayout && - "Unexpected layout in getScratchConfigForCvtLayout()"); - auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); - unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]]; - unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]]; - // TODO: Fix the legacy issue that ourOrd[0] == 0 always means - // that we cannot do vectorization. - inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; - outVec = outOrd[0] == 0 ? 
1 : dstContigPerThread; -======= assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 auto srcShapePerCTA = getShapePerCTA(srcTy); auto dstShapePerCTA = getShapePerCTA(dstTy); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index da2323660eda..8b09e7293614 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -347,15 +347,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() { for (unsigned dim : order) { if (dim == getAxis()) return stride; -<<<<<<< HEAD - stride *= ceil(type.getShape()[dim], sizePerThreads[dim] * - threadsPerWarp[dim] * - warpsPerCTA[dim]); -======= stride *= ceil(getShape()[dim], sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } llvm_unreachable("Axis not found in order"); } @@ -543,7 +537,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { !srcTy.getElementType().isF32(); } -<<<<<<< HEAD #ifdef USE_ROCM bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); @@ -562,19 +555,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { } #endif -bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { - auto src = srcTy.getEncoding().cast(); - auto dst = dstTy.getEncoding().cast(); - auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy); - auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy); - // when #mma = MmaEncoding - return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 && - dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 && - srcElemsPerThread == dstElemsPerThread; -} - -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 bool isSingleValue(Value value) { // Don't consider load as expensive if it is loading a scalar. if (auto tensorTy = value.getType().dyn_cast()) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 68c8dfa71a37..794055e21164 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,47 +1,30 @@ -<<<<<<< HEAD add_library(rocm_libraries SHARED IMPORTED ) set_target_properties(rocm_libraries PROPERTIES IMPORTED_LOCATION ${ROCM_LIBRARIES}) -======= # Separate out PTX/GCN builders to avoid cyclic dependencies as TritonAnalysis # depends on it. 
set(LLVM_OPTIONAL_SOURCES GCNAsmFormat.cpp PTXAsmFormat.cpp ) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 add_mlir_conversion_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp ConvertLayoutOpToLLVM.cpp DotOpToLLVM/FMA.cpp DotOpToLLVM/MMAv1.cpp DotOpToLLVM/MMAv2.cpp DotOpToLLVM/WGMMA.cpp + DotOpToLLVM/MFMA.cpp DotOpToLLVM.cpp ElementwiseOpToLLVM.cpp LoadStoreOpToLLVM.cpp BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp -<<<<<<< HEAD GCNAsmFormat.cpp PTXAsmFormat.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp - ConvertLayoutOpToLLVM.cpp - DotOpToLLVM/FMA.cpp - DotOpToLLVM/MMAv1.cpp - DotOpToLLVM/MMAv2.cpp - DotOpToLLVM/MFMA.cpp - DotOpToLLVM.cpp - ElementwiseOpToLLVM.cpp - LoadStoreOpToLLVM.cpp - TritonGPUToLLVM.cpp -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 TritonGPUToLLVMPass.cpp ReduceOpToLLVM.cpp ScanOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index de08b7753f79..14212fce79b1 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -1059,7 +1059,6 @@ struct ConvertLayoutOpConversion return failure(); } -<<<<<<< HEAD #ifdef USE_ROCM // shared -> dot_operand if the result layout is mma Value lowerSharedToDotOperandMFMA( @@ -1084,7 +1083,6 @@ struct ConvertLayoutOpConversion return res; } #endif -======= // mma -> mma LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, @@ -1121,7 +1119,6 @@ struct ConvertLayoutOpConversion rewriter.replaceOp(op, view); return success(); } ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 // shared -> dot_operand if the result layout is mma Value lowerSharedToDotOperandMMA( diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index b84e0a5cafe7..6bb16fe46e1b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -912,6 +912,7 @@ const std::string Fp16_to_Fp8E4M3Nv = "{ \n" "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n" "}"; +#ifndef USE_ROCM // Fp8E4M3 (x2) -> Fp16 (x2) (packed) const std::string Fp8E4M3Nv_to_Bf16 = "{ \n" @@ -937,7 +938,6 @@ const std::string Bf16_to_Fp8E4M3Nv = "}"; /* ----- Packed integer to BF16 ------ */ -#ifndef USE_ROCM const std::string S8_to_Bf16 = "{ \n" ".reg .s8 s<4>; \n" @@ -1398,12 +1398,15 @@ struct FpToFpOpConversion {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ}, // F8 -> BF16 {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, +#ifndef USE_ROCM {{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16}, +#endif // BF16 -> F8 {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2}, +#ifndef USE_ROCM {{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv}, +#endif }; - int inVecWidthBits = 32; int outVecWidthBits = 32; if (srcTy.isFloat8E4M3FNUZ()) { diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 1fdfed3883b8..fb148c4237f8 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -133,109 +133,10 @@ struct ReduceOpConversion return srcValues; } -<<<<<<< 
HEAD - // Calculates the write index in the shared memory where we would be writing - // the within-thread accumulations before we start doing across-threads - // accumulations. `index` is the index of the within-thread accumulations in - // the full tensor, whereas `writeIdx` is the mapped-to index in the shared - // memory - void getWriteIndexBasic(ConversionPatternRewriter &rewriter, Location loc, - Attribute layout, SmallVector &index, - SmallVector &writeIdx, - std::map &ints, unsigned originalAxis, - unsigned axis) const { - if (auto sliceLayout = layout.dyn_cast()) { - // Recover the axis in the parent layout - auto parentAxis = axis < sliceLayout.getDim() ? axis : axis + 1; - auto parentLayout = sliceLayout.getParent(); - getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints, - originalAxis, parentAxis); - return; - } - - writeIdx = index; - auto sizePerThread = triton::gpu::getSizePerThread(layout); - Value axisSizePerThread = ints[sizePerThread[axis]]; - Value _8 = ints[8]; - Value _16 = ints[16]; -#ifdef USE_ROCM - Value _2 = ints[2]; - Value _4 = ints[4]; - Value _32 = ints[32]; -#endif - - if (layout.isa()) { - // A single thread owns axisSizePerThread contiguous values - // on the reduction axis. After within thread reduction, - // we would have a single accumulation every `axisSizePerThread` - // contiguous values in the original tensor, so we would need - // to map every `axisSizePerThread` to 1 value in smem as: - // writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread - writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread); - } else if (auto mmaLayout = layout.dyn_cast()) { - if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) { - llvm::report_fatal_error("Unsupported layout"); - } - if (originalAxis == 0) { - // Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8 - // rows in smem would correspond to a warp. The mapping - // is: (warp_index) x 8 + (row index within warp) - writeIdx[originalAxis] = add(mul(udiv(index[originalAxis], _16), _8), - urem(index[originalAxis], _8)); - } else { - // Same as BlockedEncodingAttr case - writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread); - } - } else if (auto mfmaLayout = layout.dyn_cast()) { - // TODO: Support MFMA transposed layout. - if (axis == 0) { - // Because warpTileSize = [32, 32] and threadsPerWarp = [2, 32], each 2 - // rows in smem would correspond to a warp. 
The mapping - // is: (warp_index) x 2 + (row index within warp) - writeIdx[axis] = add(mul(udiv(index[axis], _32), _2), - udiv(urem(index[axis], _32), _4)); - } else { - // Same as BlockedEncodingAttr case - writeIdx[axis] = udiv(index[axis], axisSizePerThread); - } - } else { - llvm::report_fatal_error("Unsupported layout"); - } - } - - // Use shared memory for reduction within warps and across warps - LogicalResult - matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - ReduceOpHelper helper(op); - Location loc = op.getLoc(); - unsigned axis = op.getAxis(); - - auto srcTys = op.getInputTypes(); - auto srcLayout = helper.getSrcLayout(); - if (!helper.isSupportedLayout()) { - assert(false && "Unexpected srcLayout in ReduceOpConversion"); - } - // The order of the axes for the the threads within the warp - auto srcOrd = triton::gpu::getOrder(srcLayout); - auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); - auto srcShape = helper.getSrcShape(); - - SmallVector elemPtrTys(srcTys.size()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto ty = srcTys[i].getElementType(); - auto llvmElemTy = getTypeConverter()->convertType(ty); - elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); - } - auto llvmIndexTy = getTypeConverter()->getIndexType(); - - auto smemShape = helper.getScratchConfigBasic(); -======= SmallVector getSmemBases(ReduceOpHelper &helper, triton::ReduceOp op, SmallVector smemShape, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 unsigned elems = product(smemShape); // indices will store the index of the op operands in descending order // of their bitwidths @@ -251,93 +152,10 @@ struct ReduceOpConversion bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), getElementPtrType(op, indices[0])); for (unsigned i = 1; i < op.getNumOperands(); ++i) { -<<<<<<< HEAD - smemBases[i] = - bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), - elemPtrTys[i]); - } - - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - std::map, SmallVector> accs; - std::map, SmallVector> indices; - reduceWithinThreads(helper, srcValues, accs, indices, rewriter); - - // cached int32 constants - std::map ints; - ints[0] = i32_val(0); - for (int N = smemShape[axis] / 2; N > 0; N >>= 1) - ints[N] = i32_val(N); - ints[sizePerThread[axis]] = i32_val(sizePerThread[axis]); - ints[8] = i32_val(8); - ints[16] = i32_val(16); -#ifdef USE_ROCM - ints[2] = i32_val(2); - ints[4] = i32_val(4); - ints[32] = i32_val(32); -#endif - // reduce across threads - for (auto it : accs) { - const SmallVector &key = it.first; - auto &acc = it.second; - // get the writeIdx at which to write in smem - SmallVector writeIdx; - getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints, - axis, axis); - - // calculate the offset in smem for that writeIdx - Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd); - SmallVector writePtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - // Store the within-thread accumulated value into shared memory - writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); - store(acc[i], writePtrs[i]); - } - - SmallVector readIdx(writeIdx.size(), ints[0]); - // Perform parallel reduction with sequential addressing - // E.g. 
We reduce `smemShape[axis]` elements into `smemShape[axis]/2` - // elements using `smemShape[axis]/2` threads where each thread - // would accumalte values that are `smemShape[axis]/2` apart - // to avoid bank conflicts. Then we repeat with `smemShape[axis]/4` - // threads, .. etc. - for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { - // The readIdx will be N elements away on the reduction axis - readIdx[axis] = ints[N]; - // If the writeIdx is greater or equal to N, do nothing - Value readMask = icmp_slt(writeIdx[axis], ints[N]); - // Calculate the readOffset, if readMask is False, readOffset=0 - // meaning we reduce the value at writeIdx with itself - Value readOffset = select( - readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), - ints[0]); - SmallVector readPtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - // The readPtr is readOffset away from writePtr - readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); - } - - sync(rewriter, loc, op); - - // Combine accumulator value from another thread - SmallVector cur(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - cur[i] = load(readPtrs[i]); - } - accumulate(rewriter, op.getCombineOp(), acc, cur, false); - - sync(rewriter, loc, op); - - // Publish our new accumulator value to shared memory - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - store(acc[i], writePtrs[i]); - } - } -======= indexToBase[indices[i]] = bitcast(gep(getElementPtrType(op, indices[i - 1]), indexToBase[indices[i - 1]], i32_val(elems)), getElementPtrType(op, indices[i])); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } // smemBases[k] is the base pointer for the k-th operand SmallVector smemBases(op.getNumOperands()); @@ -485,11 +303,7 @@ struct ReduceOpConversion } #endif for (unsigned i = 0; i < acc.size(); ++i) { -<<<<<<< HEAD - shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx); -======= - shfl[i] = shflSync(loc, rewriter, acc[i], N * interleave); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx * interleave); } accumulate(rewriter, op.getCombineOp(), acc, shfl, false); } @@ -597,27 +411,6 @@ struct ReduceOpConversion auto order = getOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); -<<<<<<< HEAD - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); - -#ifdef USE_ROCM - auto srcTys = op.getInputTypes(); - auto inputTy = srcTys[0].cast(); - auto inMfma = - inputTy.getEncoding().dyn_cast(); - // Original logic is buggy for warpsPerCTA=[2, 2], but works fine for - // warpsPerCTA=[4, 1] (that is used in flash attention, thus tested). - // TODO: Check whether this is the case for MMA layout as well, if yes, this - // should be fixed in the upstream repo. 
- if (inMfma) { - multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp); - multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA); - } -#endif - -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Value laneIdAxis = multiDimLaneId[axis]; Value zero = i32_val(0); Value laneZero = icmp_eq(laneIdAxis, zero); diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index c1276a4de15c..b810ab9d1d6a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -62,12 +62,8 @@ static void scanThreadContiguousElements(SmallVector &srcValues, // contiguous group of elements. static void warpScan(SmallVector &srcValues, ConversionPatternRewriter &rewriter, -<<<<<<< HEAD ScanLoweringHelper &helper, Value laneIdAxis, - Value laneId) { -======= - ScanLoweringHelper &helper, Value laneIdAxis) { ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + Value laneId) { Location loc = helper.getLoc(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); @@ -81,14 +77,8 @@ static void warpScan(SmallVector &srcValues, // Reduce within warps. Value acc = srcValues[srcIndex]; for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) { -<<<<<<< HEAD Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride, laneId); - Value tempAcc = acc; - accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl); -======= - Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride); Value tempAcc = accumulate(rewriter, helper.getCombineOp(), shfl, acc); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Value mask = icmp_slt(laneIdAxis, i32_val(i)); acc = select(mask, acc, tempAcc); } @@ -138,11 +128,7 @@ static void AddPartialReduce(SmallVector &srcValues, ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value sharedMemoryPtr, Value warpId, Value laneIdAxis, -<<<<<<< HEAD Value parallelLaneId, Value laneId) { -======= - Value parallelLaneId) { ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Location loc = helper.getLoc(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); @@ -232,7 +218,7 @@ static void AddPartialReduce(SmallVector &srcValues, static void AddPartialReduceOneWarp(SmallVector &srcValues, ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value warpId, - Value laneIdAxis, Value laneIdLast) { + Value laneIdAxis, Value laneIdLast, Value laneId) { Location loc = helper.getLoc(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); @@ -275,7 +261,7 @@ static void AddPartialReduceOneWarp(SmallVector &srcValues, Value lastElement = srcValues[srcIndex]; if (scanDim > 1) { lastElement = - shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride); + shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride, laneId); lastElement = select(maskFirstLane, accumulator, lastElement); if (numScanBlocks > 1) // Update accumulator with the value from the last lane. 
@@ -401,10 +387,6 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, if (!helper.isSupported()) return failure(); -<<<<<<< HEAD - // Obtain global laneId and pass it around -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Value threadId = getThreadId(rewriter, loc); auto mod = op->getParentOfType(); unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); @@ -427,20 +409,6 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // elements. warpScan(srcValues, rewriter, helper, laneIdAxis, laneId); -<<<<<<< HEAD - // Store the partial reducing for each warp into shared memory. - Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); - Value baseSharedMemPtr = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); - storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, - baseSharedMemPtr, flatIdParallel); - barrier(); - // Read back the partial reduction of each warp and accumulate them based on - // warpId. Then update each chunk of contiguous elements by adding the - // accumulated value from the previous lane. - AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis, - laneIdAxis, flatIdParallel, laneId); -======= if (axisNumWarps > 1) { // Slow path for the case where there are multiple warps with unique data on // the axis. @@ -455,7 +423,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // warpId. Then update each chunk of contiguous elements by adding the // accumulated value from the previous lane. AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis, - laneIdAxis, flatIdParallel); + laneIdAxis, flatIdParallel, laneId); } else if (srcValues.size() > 1) { // Fast path for the case where there is only one warp with unique data on // the axis. @@ -466,9 +434,8 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, triton::gpu::getOrder(helper.getEncoding())); AddPartialReduceOneWarp(srcValues, rewriter, helper, warpIdAxis, laneIdAxis, - laneIdLast); + laneIdLast, laneId); } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. 
->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Value results = getTypeConverter()->packLLElements(loc, srcValues, rewriter, input.getType()); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 5385dd1350c6..f0ced79ce0e9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -275,21 +275,24 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, } static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, -<<<<<<< HEAD - Value val, int i, const std::string &shuffleType, - const std::string &clamp, Value laneId = Value()) { -======= - Value val, Value i, NVVM::ShflKind mode, - Value clamp) { ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + Value val, Value i, int strideInt, NVVM::ShflKind mode, + Value clamp, Value laneId = Value()) { unsigned bits = val.getType().getIntOrFloatBitWidth(); + //int stride = i.cast(); + //int stride = i.dyn_cast(); + //int stride = i.cast().getValue().getSExtValue(); + //int stride = i.Value(); + //constantOp.getValue().cast().getValue().getSExtValue(); + //unsigned strideint = i.cast().getValue().getSExtValue(); + //auto intAttr = i.dyn_cast_or_null(); + //auto strideint = intAttr.getValue().getSExtValue(); #ifdef USE_ROCM //On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32bit/dwords //so we need promote to 32 here. if (bits == 8) { Value i32Val = sext(i32_ty, val); - Value result = commonShflSync(loc, rewriter, i32Val, i, shuffleType, clamp, laneId); + Value result = commonShflSync(loc, rewriter, i32Val, i, strideInt, mode, clamp, laneId); return trunc(i8_ty, result); } #endif @@ -299,24 +302,19 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); -<<<<<<< HEAD - val0 = commonShflSync(loc, rewriter, val0, i, shuffleType, clamp, laneId); - val1 = commonShflSync(loc, rewriter, val1, i, shuffleType, clamp, laneId); -======= - val0 = commonShflSync(loc, rewriter, val0, i, mode, clamp); - val1 = commonShflSync(loc, rewriter, val1, i, mode, clamp); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + val0 = commonShflSync(loc, rewriter, val0, i, strideInt, mode, clamp, laneId); + val1 = commonShflSync(loc, rewriter, val1, i, strideInt, mode, clamp, laneId); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); return bitcast(vec, val.getType()); } -<<<<<<< HEAD #ifdef USE_ROCM GCNBuilder builder; - if (shuffleType == "bfly") { - if (i > 16) { + switch (mode) { + case NVVM::ShflKind::bfly: + if (strideInt > 16) { Value threadId = rewriter .create( @@ -342,13 +340,14 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, auto dOpr = builder.newOperand("=v"); auto aOpr = builder.newOperand(val, "v"); auto maskOpr = - builder.newConstantOperand("offset:" + std::to_string(masks[i])); + builder.newConstantOperand("offset:" + std::to_string(masks[strideInt])); (*shfl)(dOpr, aOpr, maskOpr); } - } else { // shuffle_up - assert(shuffleType == "up" && "Only shfl_bfly and shfl_up are supported"); - Value mask = icmp_slt(laneId, i32_val(i)); - Value delta = sub(laneId, i32_val(i)); + break; + case NVVM::ShflKind::up: + //assert(shuffleType == "up" && "Only shfl_bfly and shfl_up are supported"); + Value mask = icmp_slt(laneId, i); + Value 
delta = sub(laneId, i); Value index = select(mask, laneId, delta); Value byteOffset = i32_val(2); Value permuteAddr = shl(index, byteOffset); @@ -357,22 +356,13 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, auto addrOpr = builder.newOperand(permuteAddr, "v"); auto aOpr = builder.newOperand(val, "v"); (*shfl)(dOpr, addrOpr, aOpr); + break; } + auto swait = builder.create("s_waitcnt lgkmcnt(0)"); (*swait)(); return builder.launch(rewriter, loc, val.getType(), true); #else - PTXBuilder builder; - auto &shfl = builder.create("shfl.sync")->o(shuffleType).o("b32"); - auto *dOpr = builder.newOperand("=r"); - auto *aOpr = builder.newOperand(val, "r"); - auto *bOpr = builder.newConstantOperand(i); - auto *cOpr = builder.newConstantOperand(clamp); - auto *maskOpr = builder.newConstantOperand("0xffffffff"); - shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); - return builder.launch(rewriter, loc, val.getType(), false); -#endif -======= Type type = val.getType(); if (type != i32_ty) { val = bitcast(val, int_ty(bits)); @@ -386,24 +376,19 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, result = bitcast(result, type); } return result; ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 +#endif } Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, + return commonShflSync(loc, rewriter, val, i32_val(i), i, NVVM::ShflKind::bfly, i32_val(0x1f)); } Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, -<<<<<<< HEAD int i, Value laneId) { - return commonShflSync(loc, rewriter, val, i, "up", "0x0", laneId); -======= - int i) { - return commonShflSync(loc, rewriter, val, i32_val(i), NVVM::ShflKind::up, - i32_val(0x0)); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + return commonShflSync(loc, rewriter, val, i32_val(i), i, NVVM::ShflKind::up, + i32_val(0x0), laneId); } Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, @@ -413,7 +398,7 @@ Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i) { - return commonShflSync(loc, rewriter, val, i, NVVM::ShflKind::idx, + return commonShflSync(loc, rewriter, val, i, 0, NVVM::ShflKind::idx, i32_val(0x1f)); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 31d7950c421b..fd69a4c63a6a 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -337,16 +337,11 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, -<<<<<<< HEAD int i, Value laneId); - -======= - int i); Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, StringRef key, StringRef content); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 113e709c16dd..8d584d4b47dc 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ 
b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1768,34 +1768,8 @@ struct TritonGPUInferLayoutInterface // Canonicalizer //===----------------------------------------------------------------------===// -<<<<<<< HEAD -LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, - PatternRewriter &rewriter) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention - auto srcType = op.getOperand().getType().cast(); - auto dstType = op.getType().cast(); - if (dstType.getEncoding().isa() && - (srcType.getEncoding().isa() || - srcType.getEncoding().isa())) - return mlir::failure(); - // for hopper MMAv3 - if (!op.use_empty()) { - bool hasDotUser = false; - for (Operation *dot : op.getResult().getUsers()) - if (isa(dot)) - hasDotUser = true; - - if (hasDotUser) { - if (dstType.getEncoding().isa() && - srcType.getEncoding().isa()) - return mlir::failure(); - } - } -======= struct CanonicalizeConvertFromView : public mlir::OpRewritePattern { ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 CanonicalizeConvertFromView(MLIRContext *context) : OpRewritePattern(context, 1) {} @@ -1829,7 +1803,8 @@ struct CanonicalizeConvertFromConvert auto srcType = op.getOperand().getType().cast(); auto dstType = op.getType().cast(); if (dstType.getEncoding().isa() && - srcType.getEncoding().isa()) + (srcType.getEncoding().isa() || + srcType.getEncoding().isa())) return mlir::failure(); // for hopper MMAv3 if (!op.use_empty()) { diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 03a2e03fb184..35b5efc364af 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -226,7 +226,8 @@ class BlockedToMFMA : public mlir::RewritePattern { a = rewriter.create(a.getLoc(), newAType, a); b = rewriter.create(b.getLoc(), newBType, b); auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, - newAcc, dotOp.getAllowTF32()); + newAcc, dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); rewriter.replaceOpWithNewOp(op, oldRetType, newDot.getResult()); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 71981b97ae35..989bc5aa0944 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -318,7 +318,6 @@ void LayoutPropagation::initAnchorLayout() { if (tensorType.getEncoding().isa() && !hasConvertToMMATransisitiveUse(op, tensorType.getEncoding())) continue; -<<<<<<< HEAD #ifdef USE_ROCM // Workaround to not propagate MFMA layout in case there are // no chained dots MFMA layout is expensive to convert, so we want @@ -331,10 +330,7 @@ void LayoutPropagation::initAnchorLayout() { !hasConvertToMFMATransisitiveUse(op, tensorType.getEncoding())) continue; #endif - layouts.insert({result, tensorType.getEncoding()}); -======= layouts.insert({result, LayoutInfo(tensorType.getEncoding())}); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } } } diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt index 9b24f0ff225b..99cf364fab4d 100644 --- a/lib/Target/CMakeLists.txt +++ b/lib/Target/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(LLVMIR) add_subdirectory(PTX) +add_subdirectory(HSACO) diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 077e1a3af851..7f2fd303791e 100644 --- 
a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -1,10 +1,5 @@ #include "triton/Target/LLVMIR/LLVMIRTranslation.h" -<<<<<<< HEAD - -#include "mlir/Conversion/Passes.h" -======= #include "LLVMPasses.h" ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" #include "mlir/Conversion/Passes.h" @@ -44,14 +39,9 @@ #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SourceMgr.h" -<<<<<<< HEAD - -#include -======= #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #include @@ -469,12 +459,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass( -<<<<<<< HEAD - createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target})); -#ifndef USE_ROCM -======= createConvertTritonGPUToLLVMPass(computeCapability, target, &tmaInfos)); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 +#ifndef USE_ROCM pm.addPass(createConvertNVGPUToLLVMPass()); #endif pm.addPass(mlir::createArithToLLVMConversionPass()); diff --git a/python/src/triton.cc b/python/src/triton.cc index b05202a68472..149ef103424f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -2096,9 +2096,7 @@ void init_triton_translation(py::module &m) { const std::vector &paths) { ::mlir::triton::addExternalLibs(op, names, paths); }); -} -<<<<<<< HEAD m.def( "translate_llvmir_to_hsaco", [](const std::string llvmIR, std::string gfx_arch, std::string gfx_triple, @@ -2116,7 +2114,8 @@ void init_triton_translation(py::module &m) { return hsacoCode; }, ret::take_ownership); -======= +} + void init_triton_interpreter(py::module &&m) { using ret = py::return_value_policy; @@ -2155,7 +2154,6 @@ void init_triton_interpreter(py::module &&m) { } } }); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } void init_triton(py::module &m) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7f3530862e80..a8557b4e6430 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1316,10 +1316,7 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): if is_hip() and (dtype_z == "bfloat16"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') -<<<<<<< HEAD size = 1024 -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. 
if dtype_x.startswith('bfloat'): x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) @@ -1882,13 +1879,6 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32]]) @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("axis", [0, 1]) -<<<<<<< HEAD -def test_reduce_layouts(M, N, src_layout, axis, device): - if is_hip(): - pytest.skip("test_reduce_layouts is not supported in HIP") - - rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" -======= @pytest.mark.parametrize("reduce2d", [False, True]) @pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) @@ -1907,7 +1897,6 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, "max": np.max, "sum": np.sum }[reduce_op] ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 rdims_1d = f"{N}" if axis == 0 else f"{M}" rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" store_range = "%7" if axis == 0 else "%1" @@ -1937,40 +1926,11 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, #blocked = {blocked} #src = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ -<<<<<<< HEAD - tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ -======= tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}) {{ ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> -<<<<<<< HEAD - %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> - %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> - %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> - %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> - %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> - %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> - %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<{rdims_2d}x!tt.ptr, #blocked> - %12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr, #blocked>, tensor<{rdims_2d}xi32, #blocked> - %13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked> - %14 = {GPU_DIALECT}.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src> - %15 = "tt.reduce"(%14) ({{ - ^bb0(%arg3: i32, %arg4: i32): - %17 = arith.addi %arg3, %arg4 : i32 - tt.reduce.return %17 : i32 - }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = 
#src}}>> - %18 = {GPU_DIALECT}.convert_layout %15 : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> - %19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked> - tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked> - tt.return - }} - }} - """ -======= %4 = tt.splat %arg0 : (!tt.ptr<{ty}, 1>) -> tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked> %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x1xi32, #blocked> %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> @@ -1986,7 +1946,6 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, tt.reduce.return %17 : {ty} }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: @@ -2027,28 +1986,16 @@ def test_store_op(M, src_layout, device): ir = f""" #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ -<<<<<<< HEAD - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> - %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> - %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -======= tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src> %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> -<<<<<<< HEAD - %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> - %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> -======= %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tt.store %8, %4 : tensor<{M}x1xf32, #src> tt.return }} 
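The TTGIR-string tests above are all driven by the same small harness: the hand-written module is dumped to a temporary .ttgir file, compiled by path, and launched with a single program instance. A minimal sketch of that flow, assuming the path-based triton.compile entry point used elsewhere in test_core.py; the argument list and tensor names are illustrative, not part of the patch:

import tempfile

import triton


def run_ttgir(ir_text, *kernel_args):
    # Write the hand-written TTGIR module to a temporary file and compile it
    # by path; one program instance is enough for these layout tests.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir_text)
        f.flush()
        kernel = triton.compile(f.name)
    kernel[(1, 1, 1)](*kernel_args)

For example, the store test would be invoked roughly as run_ttgir(ir, x_tri, z_tri) after allocating the input and output tensors on the device.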
@@ -2092,16 +2039,6 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): #dst = {dst_layout} #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ -<<<<<<< HEAD - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> - %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> - %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> - %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> - %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> - %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> - %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> -======= tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> @@ -2110,7 +2047,6 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.return @@ -2174,11 +2110,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): ir = f""" #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ -<<<<<<< HEAD - tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ -======= tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = 
#src}}>>) -> tensor<{M}x1xi32, #src> @@ -2506,18 +2438,10 @@ def kernel(X, stride_xm, stride_xk, red_code = ptx[start:end] assert len(red_code) > 0 import os -<<<<<<< HEAD - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() - # skip this check on hopper because there are some functions whose name contain "shared" in ptx. - # TODO: we should eliminate these unused functions in ptx code. - if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]): -======= # skip this check on hopper because there are some functions whose name contain "shared" in ptx. # TODO: we should eliminate these unused functions in ptx code. if not (capability[0] >= 9): ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 assert "shared" not in red_code assert "bar.sync" not in red_code # torch result @@ -3766,18 +3690,11 @@ def kernel(Out): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -<<<<<<< HEAD -def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): - if is_hip(): - pytest.skip("test_convert2d is not supported in HIP") - -======= def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): if is_hip(): pytest.skip("test_convert2d is not supported in HIP") if (M == 1 or N == 1) and interm_layout: pytest.skip("Out of bound access when maxPhase > 1") ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 if str(src_layout) == str(dst_layout): pytest.skip() if 'mma' in str(src_layout) and 'mma' in str(dst_layout): diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index b4318426513e..09f739bcb687 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -20,7 +20,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): pytest.skip('bfloat16 tma not support currently') capability = torch.cuda.get_device_capability() -<<<<<<< HEAD if torch.version.hip is not None: if dtype != torch.float16: pytest.skip("Currently flash attention on AMD gpu is only supported in fp16.") @@ -31,11 +30,9 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): if capability[0] < 8: pytest.skip("Flash attention only supported for compute capability < 80") -======= interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"] if not interpreter and capability[0] < 8: pytest.skip("Flash attention only supported for compute capability >= 80") ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 torch.manual_seed(20) q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() @@ -68,15 +65,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): tri_dq, q.grad = q.grad.clone(), None # compare atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 -<<<<<<< HEAD - torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0) - if torch.version.hip is None: - torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0) - torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0) - torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0) -======= torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0) - torch.testing.assert_close(ref_dv, tri_dv, 
atol=atol, rtol=0) - torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) - torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + if torch.version.hip is None: + torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 8cce0d7a31e3..89c7e1999925 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -5,18 +5,11 @@ import json import os import re -<<<<<<< HEAD -import tempfile -from collections import namedtuple -from pathlib import Path -from typing import Any -======= from collections import namedtuple from pathlib import Path from typing import Any from dataclasses import dataclass ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, get_num_warps, @@ -71,18 +64,13 @@ def ttir_compute_capability_rewrite(mod, target): # with block (tensor) pointers into tensors of pointers pm = ir.pass_manager(mod.context) pm.enable_debug() -<<<<<<< HEAD - if _is_cuda(arch): - pm.add_rewrite_tensor_pointer_pass(arch, False) + if _is_cuda(target): + pm.add_rewrite_tensor_pointer_pass(target.capability, False) elif is_hip(): capability = 90 pm.add_rewrite_tensor_pointer_pass(capability, True) else: assert(False, "unsupported target") -======= - if _is_cuda(target): - pm.add_rewrite_tensor_pointer_pass(target.capability) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.run(mod) return mod @@ -103,34 +91,22 @@ def optimize_ttir(mod, target): return mod -<<<<<<< HEAD -def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch): +def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target): pm = ir.pass_manager(mod.context) pm.enable_debug() if is_hip(): pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0) else: - pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch) -======= -def ttir_to_ttgir(mod, num_warps, num_ctas, target): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, target.capability) pm.run(mod) return mod -<<<<<<< HEAD -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, - cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type): -======= def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, - cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue): + cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type): is_cuda = _is_cuda(target) if is_cuda: capability = target.capability ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_tritongpu_coalesce_pass() @@ -140,18 +116,13 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) pm.add_plan_cta_pass(cluster_info) pm.add_tritongpu_remove_layout_conversions_pass() -<<<<<<< HEAD - if _is_cuda(arch): - pm.add_tritongpu_accelerate_matmul_pass(arch) + if is_cuda: + pm.add_tritongpu_accelerate_matmul_pass(capability) 
# TODO change interface of accelerate_matmul_pass if is_hip(): matrix_core_version = gpu_matrix_core_version() matrix_inst_size = matrix_inst_type pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size) -======= - if is_cuda: - pm.add_tritongpu_accelerate_matmul_pass(capability) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.add_tritongpu_remove_layout_conversions_pass() if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() @@ -165,13 +136,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, # it's the responsibility of the compiler to figure out the exact # `num_warps` to use. # TODO: support the case where `num_warps` from user is not 4. -<<<<<<< HEAD - if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4: - pm.add_tritongpu_ws_feasibility_checking_pass(arch) -======= - if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: + if is_cuda and capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: pm.add_tritongpu_ws_feasibility_checking_pass(capability) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.run(mod) ws_enabled = ir.is_ws_supported(mod) pm = ir.pass_manager(mod.context) @@ -183,23 +149,17 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, pm.add_tritongpu_wsmaterialization_pass(capability) pm.add_cse_pass() else: -<<<<<<< HEAD if is_hip(): pm.add_tritongpu_pipeline_pass( num_stages, num_warps, num_ctas, 0) else: pm.add_tritongpu_pipeline_pass( - num_stages, num_warps, num_ctas, arch) + num_stages, num_warps, num_ctas, capability) if is_hip(): pm.add_tritongpu_materialize_load_store_pass(num_warps, 0) else: - pm.add_tritongpu_materialize_load_store_pass(num_warps, arch) - if _is_cuda(arch) and arch // 10 <= 8: -======= - pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) - pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) - if capability // 10 <= 8: ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) + if is_cuda and capability // 10 <= 8: pm.add_tritongpu_prefetch_pass() pm.add_tritongpu_optimize_dot_operands_pass() pm.add_tritongpu_remove_layout_conversions_pass() @@ -209,11 +169,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, pm.add_tritongpu_reorder_instructions_pass() pm.add_cse_pass() pm.add_symbol_dce_pass() -<<<<<<< HEAD - if _is_cuda(arch) and arch // 10 >= 9: -======= - if capability // 10 >= 9: ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + if is_cuda and capability // 10 >= 9: pm.add_tritongpu_fence_insertion_pass() pm.add_tritongpu_ws_fixup_missing_attrs_pass() pm.run(mod) @@ -227,21 +183,12 @@ def _add_external_libs(mod, libs): add_external_libs(mod, list(libs.keys()), list(libs.values())) -<<<<<<< HEAD -def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0): - if extern_libs: - _add_external_libs(mod, extern_libs) - # TODO: separate tritongpu_to_llvmir for different backends - if _is_cuda(arch): - return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu) -======= -def ttgir_to_llir(mod, extern_libs, target, tma_infos): +def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0): if extern_libs: _add_external_libs(mod, extern_libs) # TODO: separate tritongpu_to_llvmir for different backends if _is_cuda(target): - return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) ->>>>>>> 
ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu) else: return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu) @@ -284,11 +231,7 @@ def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): :return: str ''' ptxas, _ = path_to_ptxas() -<<<<<<< HEAD - return compile_ptx_to_cubin(ptx, ptxas, arch) -======= return compile_ptx_to_cubin(ptx, ptxas, target.capability) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # ------------------------------------------------------------------------------ @@ -333,11 +276,7 @@ def make_hash(fn, target, env_vars, **kwargs): get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] -<<<<<<< HEAD - key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" -======= - key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) ignore_version = kwargs.get('ignore_version', False) @@ -374,11 +313,7 @@ def make_hash(fn, target, env_vars, **kwargs): "ptx": ptx_arg_type_pattern, } if is_hip(): -<<<<<<< HEAD - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' -======= ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:' ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' @@ -407,11 +342,6 @@ def parse_mlir_module(path, context): instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()]) -<<<<<<< HEAD -# TODO: architecture descriptor class -def _is_cuda(arch): - return isinstance(arch, int) - def is_hip(): try: import torch @@ -421,20 +351,6 @@ def is_hip(): from ..language.semantic import gpu_matrix_core_version -@functools.lru_cache -def get_architecture_descriptor(capability): - if is_hip(): - _device_backend = get_backend("hip") - assert _device_backend - arch = _device_backend.get_architecture_descriptor() - return arch - else: - if capability is None: - device = get_current_device() - capability = get_device_capability(device) - capability = capability[0] * 10 + capability[1] - return capability -======= def get_cuda_capability(capability): if capability is None: device = get_current_device() @@ -442,8 +358,6 @@ def get_cuda_capability(capability): capability = capability[0] * 10 + capability[1] return capability ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 - @functools.lru_cache def get_arch_default_num_warps(device_type): if device_type in ["cuda"]: @@ -457,15 +371,8 @@ def 
get_arch_default_num_warps(device_type): @functools.lru_cache def get_arch_default_num_stages(device_type, capability=None): -<<<<<<< HEAD - if device_type in ["cuda"]: - arch = get_architecture_descriptor(capability) - is_cuda = device_type == "cuda" and _is_cuda(arch) - num_stages = 3 if is_cuda and arch >= 75 else 2 -======= if device_type == "cuda": num_stages = 3 if get_cuda_capability(capability) >= 75 else 2 ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: _device_backend = get_backend(device_type) assert _device_backend @@ -475,12 +382,7 @@ def get_arch_default_num_stages(device_type, capability=None): return num_stages -<<<<<<< HEAD -def add_cuda_stages(arch, extern_libs, stages): -======= def add_cuda_stages(target, extern_libs, stages): - ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target)) stages["cubin"] = (lambda path: Path(path).read_bytes(), @@ -494,27 +396,11 @@ def compile(fn, **kwargs): if is_hip(): device_type = "hip" -<<<<<<< HEAD capability = None - if device_type == "cuda": - _device_backend = get_backend(device_type) - arch = get_architecture_descriptor(capability) - else: - _device_backend = get_backend(device_type) - assert _device_backend - arch = _device_backend.get_architecture_descriptor(**kwargs) - - is_cuda = device_type == "cuda" and _is_cuda(arch) - if is_hip(): - is_cuda = False - warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch["warp_size"] -======= is_cuda = device_type == "cuda" if is_hip(): is_cuda = False - ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 context = ir.context() constants = kwargs.get("constants", dict()) num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) @@ -542,9 +428,6 @@ def compile(fn, **kwargs): cluster_info.clusterDimY = kwargs["clusterDims"][1] cluster_info.clusterDimZ = kwargs["clusterDims"][2] tma_infos = TMAInfos() -<<<<<<< HEAD - -======= # build architecture descriptor if device_type == "cuda": _device_backend = get_backend(device_type) @@ -553,24 +436,23 @@ def compile(fn, **kwargs): _device_backend = get_backend(device_type) assert _device_backend target = _device_backend.get_architecture_descriptor(**kwargs) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + warp_size = CUDA_DEFAULT_WARP_SIZE if is_cuda else target["warp_size"] # build compilation stages stages = dict() stages["ast"] = (lambda path: fn, None) stages["ttir"] = (lambda path: parse_mlir_module(path, context), -<<<<<<< HEAD - lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch)) + lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) if is_cuda: stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) - add_cuda_stages(arch, extern_libs, stages) + lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) + add_cuda_stages(target, extern_libs, stages) elif device_type == "hip": # pass the 
user's configuration to the backend device. - arch["num_warps"] = num_warps - arch["num_stages"] = num_stages - arch["num_ctas"] = num_ctas + target["num_warps"] = num_warps + target["num_stages"] = num_stages + target["num_ctas"] = num_ctas other = {} other["context"] = context @@ -583,24 +465,13 @@ def compile(fn, **kwargs): other["waves_per_eu"] = waves_per_eu other["matrix_instr_nonkdim"] = matrix_instr_nonkdim - _device_backend.add_stages(arch, extern_libs, stages, other) + _device_backend.add_stages(target, extern_libs, stages, other) elif device_type == "xpu": stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) stages["llir"] = (lambda path: Path(path).read_text(), lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) _device_backend.add_stages(arch, extern_libs, stages) -======= - lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) - if is_cuda: - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) - stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) - add_cuda_stages(target, extern_libs, stages) - elif device_type == "hip": - _device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: # pass the user's configuration to the backend device. 
target["num_warps"] = num_warps @@ -735,19 +606,11 @@ def compile(fn, **kwargs): else: metadata["shared"] = get_shared_memory_size(module) if ir_name == "ttgir": -<<<<<<< HEAD metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) if metadata["enable_warp_specialization"]: if is_hip(): metadata["num_warps"] = _device_backend.get_num_warps(next_module) else: -======= - if is_hip(): - metadata["num_warps"] = _device_backend.get_num_warps(next_module) - else: - metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) - if metadata["enable_warp_specialization"]: ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 metadata["num_warps"] = get_num_warps(next_module) if ir_name == "ptx": metadata["name"] = get_kernel_name(next_module, pattern='// .globl') diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index ea58455f372b..c7dd75ec72a8 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -100,11 +100,7 @@ def format_of(ty): # generate glue code folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] -<<<<<<< HEAD - params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] -======= params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)] ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 src = f""" #include \"cuda.h\" #include diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index b6194ad3bfd2..a5d38a6d0318 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1354,8 +1354,7 @@ def dot(lhs: tl.tensor, max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: -<<<<<<< HEAD - def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): + def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): if is_hip(): assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \ (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \ @@ -1364,35 +1363,12 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): return # Checks for non-cuda archs - if _is_cuda(builder.arch): + if target.capability < 90: # Checks for cuda arch if arch < 90: assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" -======= - def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): - # Checks for non-cuda archs - if not _is_cuda(target): - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - return - # Checks for cuda arch - if target.capability < 90: - assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" - if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): - return - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - else: - assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" - assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" - if lhs_dtype.is_int() or rhs_dtype.is_int(): - assert lhs_dtype == rhs_dtype, f"Both operands must be same type. 
First operand ({lhs_dtype}) and second operand ({rhs_dtype})" - assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" - elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): - assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" - assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})" - else: - assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" - assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" else: assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" @@ -1412,12 +1388,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" return -<<<<<<< HEAD assert lhs.type.is_block() and rhs.type.is_block() - assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch) -======= assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" @@ -1480,11 +1452,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) return cast(ret, ret_scalar_ty, builder) -<<<<<<< HEAD _0 = builder.create_splat(_0, [M, N]) -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ret_ty = tl.block_type(ret_scalar_ty, [M, N]) if acc is None: acc_handle = builder.create_splat(_0, [M, N]) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 72605e621935..b558b1380132 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -351,7 +351,7 @@ def _make_launcher(self): def regular_args_v(args_proxy): return [args_proxy[arg_name] for arg_name in regular_args] - def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type): + def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type): from ..compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps) @@ -402,7 +402,7 @@ def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp if num_stages is None: num_stages = get_arch_default_num_stages(device_type) - key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug) + key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug) if extern_libs is not None: key = (key, tuple(extern_libs.items())) @@ -430,8 +430,8 @@ def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp for i, 
arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs): - bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) + if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs): + bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) # Create tensormaps and append to args args = bin.assemble_tensormap_to_arg(args) if not warmup: @@ -446,90 +446,8 @@ def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp args_signature = args_signature + ', ' if len(args_signature) > 0 else '' src = f""" import triton -<<<<<<< HEAD def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): - from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages - sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()} - constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()} - spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()} - assert num_ctas > 0 - assert grid is not None - if callable(grid): - grid = grid({{{grid_args}}}) - grid_size = len(grid) - grid_0 = grid[0] - grid_1 = grid[1] if grid_size > 1 else 1 - grid_2 = grid[2] if grid_size > 2 else 1 - - if device_type is None: - device_types = [_device_type for _device_type in {device_types} if _device_type != ''] - device_type = self._conclude_device_type(device_types, {pinned_memory_flags}) - - device_backend = None - if device_type not in ['cuda']: - device_backend = get_backend(device_type) - if device_backend is None: - raise ValueError('Cannot find backend for ' + device_type) - - if device is None: - if device_type in ['cuda']: - device = get_current_device() - set_current_device(device) - else: - device = device_backend.get_current_device() - device_backend.set_current_device(device) - if stream is None and not warmup: - if device_type in ['cuda']: - stream = get_cuda_stream(device) - else: - stream = device_backend.get_stream() - - if num_warps is None: - num_warps = get_arch_default_num_warps(device_type) - if num_stages is None: - num_stages = get_arch_default_num_stages(device_type) - - key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug) - if not extern_libs is None: - key = (key, tuple(extern_libs.items())) - - bin = cache[device].get(key, None) - if bin is not None: - # build dict of constant values - args = [{args}] - # Create tensormaps and append to args - args = bin.assemble_tensormap_to_arg(args) - if not warmup: - bin.c_wrapper(grid_0, 
grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) - return bin - # kernel not cached -- compile - else: - # build dict of constant values - args = [{args}] - all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()} - configs = self._get_config(*all_args), - constants = self._make_constants(constexpr_key) - constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}}) - constants.update({{i: 1 for i in configs[0].equal_to_1}}) - # build kernel signature -- doesn't include specialized arguments - signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} - # build stub signature -- includes arguments that are specialized - for i, arg in constants.items(): - if callable(arg): - raise TypeError(f"Callable constexpr at index {{i}} is not supported") - if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs): - bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) - # Create tensormaps and append to args - args = bin.assemble_tensormap_to_arg(args) - if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) - self.cache[device][key] = bin - return bin - return None -======= -def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): - return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type) """ scope = {"launcher_body": launcher_body} exec(src, scope) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 719ba3bf6be5..8f8ca18c48f5 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -148,7 +148,6 @@ def _attn_fwd( stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, -<<<<<<< HEAD N_CTX, BLOCK_DMODEL: tl.constexpr, STAGE: tl.constexpr, @@ -161,12 +160,11 @@ def _attn_fwd( qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( base=Q + qkv_offset, -======= - N_CTX: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr, + pre_load_v: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -177,7 +175,6 @@ def _attn_fwd( # block pointers Q_block_ptr = tl.make_block_ptr( base=Q + qvk_offset, ->>>>>>> 
ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), @@ -193,26 +190,13 @@ def _attn_fwd( order=(1, 0), ) K_block_ptr = tl.make_block_ptr( -<<<<<<< HEAD - base=K + qkv_offset, -======= base=K + qvk_offset, ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1), ) -<<<<<<< HEAD - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0) -======= O_block_ptr = tl.make_block_ptr( base=Out + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), @@ -220,7 +204,6 @@ def _attn_fwd( offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -229,7 +212,6 @@ def _attn_fwd( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) -<<<<<<< HEAD # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop @@ -274,51 +256,13 @@ def _attn_fwd( block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) -======= - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) - # stage 1: off-band - if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, qk_scale, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 1, offs_m, offs_n, - ) - # barrier makes it easier for compielr to schedule the - # two loops independently - tl.debug_barrier() - # stage 2: on-band - if STAGE & 2: - acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, qk_scale, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, - ) - # epilogue - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tl.store(O_block_ptr, acc.to(Out.type.element_ty)) @triton.jit -<<<<<<< HEAD def _bwd_preprocess( Out, DO, NewDO, Delta, -======= -def _attn_bwd_preprocess( - O, DO, - Delta, - Z, H, N_CTX, ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) @@ -329,12 +273,8 @@ def _attn_bwd_preprocess( do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) # write-back -<<<<<<< HEAD tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) tl.store(Delta + off_m, delta) -======= - tl.store(Delta + off_hz * N_CTX + off_m, delta) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # The main inner-loop logic for computing dK and dV. 
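For reference, the _bwd_preprocess kernel kept above only materializes the per-row correction term consumed by the dK/dV and dQ loops that follow. A plain PyTorch sketch of the same quantities (for clarity only, not part of the tutorial):

import torch


def bwd_preprocess_reference(o: torch.Tensor, do: torch.Tensor):
    # Mirrors `delta = tl.sum(o * do, axis=1)` with do promoted to fp32,
    # i.e. delta[m] = sum_k O[m, k] * dO[m, k].
    do_f32 = do.float()
    delta = (o.float() * do_f32).sum(dim=-1)
    return do_f32, delta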
@@ -467,7 +407,6 @@ def _attn_bwd( pid = tl.program_id(0) # offset pointers for batch/head -<<<<<<< HEAD Q += off_z * stride_qz + off_h * stride_qh K += off_z * stride_kz + off_h * stride_kh V += off_z * stride_vz + off_h * stride_vh @@ -539,117 +478,6 @@ def _attn_bwd( dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) tl.store(dk_ptrs, dk) tl.store(dv_ptrs, dv) -======= - Q += adj - K += adj - V += adj - DO += adj - DQ += adj - DK += adj - DV += adj - M += off_chz - D += off_chz - - # load scales - offs_k = tl.arange(0, BLOCK_DMODEL) - - # THIS BLOCK DOES DK/DV/DR: - - start_n = pid * BLOCK_N1 - start_m = start_n - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - offs_n = start_n + tl.arange(0, BLOCK_N1) - - dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - - num_steps = BLOCK_N1 // MASK_BLOCK_M1 - - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=True, - ) - - start_m += num_steps * MASK_BLOCK_M1 - num_steps = (N_CTX - start_m) // BLOCK_M1 - - # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=False, - ) - - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) - - # Write back dK. - dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) - - # THIS BLOCK DOES DQ: - start_m = pid * BLOCK_M2 - end_n = start_m + BLOCK_M2 - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - offs_m = start_m + tl.arange(0, BLOCK_M2) - - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - - m = tl.load(M + offs_m) - m = m[:, None] - - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _attn_bwd_dq, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq( - dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, - MASK=True, - ) - end_n -= num_steps * MASK_BLOCK_N2 - # stage 2 - num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq( - dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * BLOCK_N2, num_steps, - MASK=False, - ) - # Write back dQ. 
- dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - dq *= LN2 - tl.store(dq_ptrs, dq) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 @triton.jit def _bwd_kernel_dk_dv( @@ -873,8 +701,12 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} +<<<<<<< HEAD <<<<<<< HEAD o = torch.empty_like(q, dtype=v.dtype) +======= + o = torch.empty_like(q) +>>>>>>> 0eed50883... resolve some merge conflicts if torch.version.hip is None: BLOCK_M = 128 BLOCK_N = 64 if Lk <= 64 else 32 @@ -889,6 +721,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): ) M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) +<<<<<<< HEAD ======= o = torch.empty_like(q) BLOCK_M = 128 @@ -898,6 +731,8 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) >>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 +======= +>>>>>>> 0eed50883... resolve some merge conflicts _attn_fwd[grid]( q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -906,7 +741,6 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], N_CTX=q.shape[2], -<<<<<<< HEAD BLOCK_DMODEL=Lk, STAGE=stage, ) @@ -915,6 +749,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk) block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) +<<<<<<< HEAD ======= BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, @@ -925,20 +760,19 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): ) >>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 +======= + +>>>>>>> 0eed50883... 
resolve some merge conflicts ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = Lk ctx.causal = causal -<<<<<<< HEAD ctx.split_kernel = split_kernel -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 return o @staticmethod def backward(ctx, do): -<<<<<<< HEAD # configuration is not supported assert(not (ctx.split_kernel and not ctx.causal)) if torch.version.hip is not None: @@ -1011,53 +845,10 @@ def backward(ctx, do): ) # print(h.asm["ttgir"]) return dq, dk, dv, None, None, None -======= - q, k, v, o, M = ctx.saved_tensors - assert do.is_contiguous() - assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - BATCH, N_HEAD, N_CTX = q.shape[:3] - PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 1 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) - arg_k = k - arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 - assert N_CTX % PRE_BLOCK == 0 - pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) - delta = torch.empty_like(M) - _attn_bwd_preprocess[pre_grid]( - o, do, - delta, - BATCH, N_HEAD, N_CTX, - BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL, - ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, - M, delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - N_HEAD, N_CTX, - BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, - BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - ) - - return dq, dk, dv, None, None - ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 attention = _attention.apply -<<<<<<< HEAD @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64), (4, 48, 2048, 64), @@ -1071,17 +862,39 @@ def backward(ctx, do): @pytest.mark.parametrize('causal', [False, True]) def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): torch.manual_seed(20) +<<<<<<< HEAD q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() if TORCH_HAS_FP8E5: q = q.to(torch_dtype) k = k.to(torch_dtype) +======= + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) +>>>>>>> 0eed50883... resolve some merge conflicts sm_scale = 0.5 dout = torch.randn_like(q, dtype=torch.float16) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) +<<<<<<< HEAD p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale +======= + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale +>>>>>>> 0eed50883... 
resolve some merge conflicts if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() @@ -1166,12 +979,12 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): HAS_FLASH = False # vary seq length for fixed head and batch=4 -<<<<<<< HEAD configs = [] for mode in ['fwd', 'bwd']: for D_HEAD in [128, 64]: if mode == 'bwd' and D_HEAD == 128: continue +<<<<<<< HEAD for causal in [False, True]: if mode == 'bwd' and causal == False: continue @@ -1224,6 +1037,25 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): for causal in [True] ] >>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 +======= + configs.append(triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 15)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal}) + ) +>>>>>>> 0eed50883... resolve some merge conflicts @triton.testing.perf_report(configs) @@ -1246,13 +1078,8 @@ def bench_flash_attention( q = q.to(torch_dtype) k = k.to(torch_dtype) sm_scale = 1.3 -<<<<<<< HEAD fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel) if mode == 'bwd': -======= - fn = lambda: attention(q, k, v, causal, sm_scale) - if mode == "bwd": ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py index 5b4305b7931c..ee1328cd8517 100644 --- a/python/tutorials/11-grouped-gemm.py +++ b/python/tutorials/11-grouped-gemm.py @@ -1,14 +1,3 @@ -<<<<<<< HEAD -======= - -""" -Group GEMM -============================ -This group gemm kernel launches a fixed number of CTA to compute a group -of gemms. The scheduling is static and we do it on device. -""" - ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining @@ -35,16 +24,10 @@ import triton import triton.language as tl -<<<<<<< HEAD # This group gemm kernel launches a fixed number of CTA to compute a group # of gemms. 
The scheduling is static and we do it on device -@triton.autotune( - configs= [ -======= - @triton.autotune( configs=[ ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 triton.Config( { 'BLOCK_SIZE_M': 128, @@ -77,7 +60,6 @@ 'NUM_SM': 128, } ), -<<<<<<< HEAD ] if torch.version.hip is None else [ triton.Config( { @@ -131,10 +113,6 @@ ), ], key=['SUM_M', 'SUM_N', 'SUM_K'], -======= - ], - key=['group_size'], ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) @triton.jit def grouped_matmul_kernel( @@ -150,12 +128,9 @@ def grouped_matmul_kernel( g_lds, # number of gemms group_size, -<<<<<<< HEAD SUM_M: tl.constexpr, SUM_N: tl.constexpr, SUM_K: tl.constexpr, -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # number of virtual SM NUM_SM: tl.constexpr, # tile sizes @@ -236,12 +211,9 @@ def group_gemm_fn(group_A, group_B): g_sizes = [] g_lds = [] group_C = [] -<<<<<<< HEAD SUM_M = 0 SUM_N = 0 SUM_K = 0 -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 for i in range(group_size): A = group_A[i] B = group_B[i] @@ -254,12 +226,9 @@ def group_gemm_fn(group_A, group_B): B_addrs.append(B.data_ptr()) C_addrs .append(C.data_ptr()) g_sizes += [M, N, K] -<<<<<<< HEAD SUM_M += M SUM_N += N SUM_K += K -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors @@ -281,12 +250,9 @@ def group_gemm_fn(group_A, group_B): d_g_sizes, d_g_lds, group_size, -<<<<<<< HEAD SUM_M=SUM_M, SUM_N=SUM_N, SUM_K=SUM_K, -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) return group_C @@ -311,7 +277,6 @@ def group_gemm_fn(group_A, group_B): tri_out = group_gemm_fn(group_A, group_B) ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] -<<<<<<< HEAD rtol = 0 if torch.version.hip is None else 1e-2 for i in range(group_size): assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=rtol) @@ -319,14 +284,6 @@ def group_gemm_fn(group_A, group_B): # only launch the kernel, no tensor preparation here to remove all overhead def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, sum_m, sum_n, sum_k): -======= -for i in range(group_size): - assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0) - - -# only launch the kernel, no tensor preparation here to remove all overhead -def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 grid = lambda META: (META['NUM_SM'],) grouped_matmul_kernel[grid]( a_ptrs, @@ -335,12 +292,9 @@ def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): sizes, lds, group_size, -<<<<<<< HEAD sum_m, sum_n, sum_k, -======= ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) @@ -401,11 +355,7 @@ def benchmark(N, provider): if provider == 'cublas': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) if provider == 'triton': -<<<<<<< HEAD ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, group_size*N, group_size*N, group_size*N), quantiles=quantiles) -======= - ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 return ms, max_ms, min_ms diff --git a/test/Conversion/minimize_alloc.mlir b/test/Conversion/minimize_alloc.mlir index 8a33aa8bfa6c..f534cbdab3f9 100644 --- a/test/Conversion/minimize_alloc.mlir +++ 
b/test/Conversion/minimize_alloc.mlir @@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %66 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1> %67 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> %68 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> - %69 = tt.dot %67, %68, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<64x64xf32, #mfma> + %69 = tt.dot %67, %68, %59 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<64x64xf32, #mfma> %70 = tt.addptr %62, %cst_0 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> %71 = tt.addptr %63, %50 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi32, #blocked1> %72 = triton_gpu.convert_layout %65 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared> @@ -94,7 +94,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : ^bb3: // pred: ^bb1 %75 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> %76 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> - %77 = tt.dot %75, %76, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<64x64xf32, #mfma> + %77 = tt.dot %75, %76, %59 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<64x64xf32, #mfma> %78 = arith.truncf %77 : tensor<64x64xf32, #mfma> to tensor<64x64xf16, #mfma> %79 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked1> %80 = arith.muli %79, %27 : tensor<64x1xi32, #blocked1> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 9b6110fed40e..d1574c24f4f7 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1283,19 +1283,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %BB_DOT = triton_gpu.convert_layout %BB : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> -<<<<<<< HEAD // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> -======= - // CHECK: llvm.inline_asm - // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - // CHECK: llvm.inline_asm - // CHECK-SAME: 
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tt.return } @@ -1329,7 +1321,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mfma0> // GCN-COUNT-4: rocdl.mfma.f32.32x32x8f16 - %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #dot_operand_a> * tensor<32x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mfma0> + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x32xf16, #dot_operand_a> * tensor<32x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mfma0> tt.return } @@ -1501,7 +1493,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mfma> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mfma> // GCN-COUNT-32: rocdl.mfma.f32.32x32x8f16 %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mfma>) -> tensor<128x256xf32, #blocked> @@ -1583,7 +1575,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b> -<<<<<<< HEAD // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 // PTX: llvm.inline_asm @@ -1592,18 +1583,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // PTX-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> -======= - // CHECK: llvm.inline_asm - // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 - // CHECK: llvm.inline_asm - // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 - // CHECK: llvm.inline_asm - // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 - // CHECK: llvm.inline_asm - // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> @@ -1925,12 
+1905,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- // CHECK-LABEL: sum_reduction -<<<<<<< HEAD // PTX: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 // PTX: nvvm.redux.sync add %{{.*}}, %[[M]] // PTX: nvvm.barrier0 -// PTX: shfl.sync.bfly.b32 -// PTX: shfl.sync.bfly.b32 +// PTX: nvvm.shfl.sync bfly +// PTX: nvvm.shfl.sync bfly // PTX: nvvm.barrier0 // GCN-COUNT-4: ds_swizzle_b32 @@ -1946,14 +1925,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // GCN: rocdl.barrier // GCN: llvm.load // GCN: llvm.store -======= -// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 -// CHECK: nvvm.redux.sync add %{{.*}}, %[[M]] -// CHECK: nvvm.barrier0 -// CHECK: nvvm.shfl.sync bfly -// CHECK: nvvm.shfl.sync bfly -// CHECK: nvvm.barrier0 ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { @@ -2041,7 +2012,26 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c // ----- -<<<<<<< HEAD +// CHECK-LABEL: copyitem +// CHECK: st.shared.b8 +// CHECK: ld.shared.b8 +// CHECK-NOT: st.shared.b1 +// CHECK-NOT: ld.shared.b1 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @copyitem() attributes {noinline = false} { + %cst = arith.constant dense : tensor<4x1xi1, #blocked> + %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %1 = arith.ori %arg0, %arg1 : i1 + tt.reduce.return %1 : i1 + }) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1,1], CTASplitNum = [1,1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: atomic_add_f16 @@ -2059,22 +2049,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2 // GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr, f16 %8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> -======= -// CHECK-LABEL: copyitem -// CHECK: st.shared.b8 -// CHECK: ld.shared.b8 -// CHECK-NOT: st.shared.b1 -// CHECK-NOT: ld.shared.b1 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @copyitem() attributes {noinline = false} { - %cst = arith.constant dense : tensor<4x1xi1, #blocked> - %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ - ^bb0(%arg0: i1, %arg1: i1): - %1 = arith.ori %arg0, %arg1 : i1 - tt.reduce.return %1 : i1 - }) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tt.return } } diff --git a/test/TritonGPU/stream-pipeline.mlir b/test/TritonGPU/stream-pipeline.mlir index e6ab4e5df012..ffa574b58797 100644 --- a/test/TritonGPU/stream-pipeline.mlir +++ b/test/TritonGPU/stream-pipeline.mlir @@ -69,7 +69,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -133,7 +133,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -194,7 +194,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -243,7 +243,7 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> 
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr @@ -294,7 +294,7 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> @@ -346,7 +346,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %119 = tt.dot %117, %118, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %131 = arith.index_cast %arg9 : index to i32 %120 = arith.addi %131, %c1_i32 : i32 %121 = arith.muli %120, %c32_i32 : i32 @@ -404,7 +404,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %153 = tt.dot %151, %152, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx 
= 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %162 = arith.index_cast %arg9 : index to i32 %154 = arith.addi %162, %c2_i32 : i32 %155 = arith.muli %154, %c32_i32 : i32 @@ -529,7 +529,7 @@ tt.func @matmul_mixed_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : %96 = tt.fp_to_fp %91 : tensor<64x32xf8E4M3FNUZ, #blocked1> -> tensor<64x32xf16, #blocked1> %97 = triton_gpu.convert_layout %96 : (tensor<64x32xf16, #blocked1>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %98 = triton_gpu.convert_layout %95 : (tensor<32x32xf16, #blocked2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %99 = tt.dot %97, %98, %arg10 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf16, #blocked> + %99 = tt.dot %97, %98, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf16, #blocked> %100 = tt.addptr %arg11, %cst_0 : tensor<64x32x!tt.ptr, #blocked1>, tensor<64x32xi32, #blocked1> %101 = tt.addptr %arg12, %65 : tensor<32x32x!tt.ptr, #blocked2>, tensor<32x32xi32, #blocked2> scf.yield %99, %100, %101 : tensor<64x32xf16, #blocked>, tensor<64x32x!tt.ptr, #blocked1>, tensor<32x32x!tt.ptr, #blocked2> From c3132eeda8fb8c9e25398e459a3fcbadb39a2f08 Mon Sep 17 00:00:00 2001 From: Jason Furmanek Date: Tue, 31 Oct 2023 14:21:13 +0000 Subject: [PATCH 117/122] ROCM IFU: Third-party fixes: preload ROCDL Additional 3rd-party fix Remove redundant mfma_supported defines --- .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 2 +- python/triton/compiler/compiler.py | 2 +- python/triton/language/semantic.py | 13 ------------- python/triton/third_party/hip/hip_backend.py | 2 +- 4 files changed, 3 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 94ed93fa1b73..1e4d6443f9b9 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -401,7 +401,7 @@ struct ConvertTritonGPUToLLVM void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + NVVM::NVVMDialect, ROCDL::ROCDLDialect>(); } ConvertTritonGPUToLLVM(int32_t computeCapability, Target target, diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 89c7e1999925..66c719c23fd2 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -360,7 +360,7 @@ def get_cuda_capability(capability): @functools.lru_cache def get_arch_default_num_warps(device_type): - if device_type in ["cuda"]: + if device_type in ["cuda", "hip"]: num_warps = 4 else: _device_backend = get_backend(device_type) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index a5d38a6d0318..cc69a756d8d7 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1334,19 +1334,6 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: return False return True -def gpu_has_mfma() -> bool: - if not is_hip(): - return False - return True # mfma supported in ['gfx908', 'gfx90a'] - - -def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: - if not gpu_has_mfma(): - return False - # TODO: Add check for configurations and types. 
- return True - - def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, diff --git a/python/triton/third_party/hip/hip_backend.py b/python/triton/third_party/hip/hip_backend.py index 00d547e10909..53a0c1d32ef9 100644 --- a/python/triton/third_party/hip/hip_backend.py +++ b/python/triton/third_party/hip/hip_backend.py @@ -64,7 +64,7 @@ def ty_to_cpp(ty): def generate_launcher_hip(constants, signature, ids): start_desc = len(signature) - signature = generate_cu_signature(constants, signature, ids) + #signature = generate_cu_signature(constants, signature, ids) arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): From 39e8901d7a4dbb3cdef43384d1e435a282541b30 Mon Sep 17 00:00:00 2001 From: Jason Furmanek Date: Tue, 7 Nov 2023 03:10:04 +0000 Subject: [PATCH 118/122] ROCM IFU: Resolve merge conflicts in RemoveLayoutConversions.cpp fix merge error fix dot fix make_range additional fix --- .../Transforms/RemoveLayoutConversions.cpp | 23 +------ python/triton/compiler/compiler.py | 2 +- python/triton/language/semantic.py | 62 ++++++++++--------- 3 files changed, 35 insertions(+), 52 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 989bc5aa0944..af6ae1c301af 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1003,22 +1003,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { if (targetType.getEncoding().isa()) return; -<<<<<<< HEAD -#ifndef USE_ROCM - auto isExtOp = [](Operation *op) { - return isa(op); -======= auto isExtOrBroadcastOp = [](Operation *op) { return isa(op); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 }; -#else - auto isExtOp = [](Operation *op) { - return isa(op); - }; -#endif + // 1. Take a backward slice of all the tensor dependencies. SetVector slice; DenseMap layout; @@ -1064,22 +1053,12 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]); if (!srcEncoding) return; - std::optional srcEncoding = - inferSrcEncoding(extOp, layout[extOp->getResult(0)]); // Move the convert before the ext op and rewrite the slice. 
-<<<<<<< HEAD - OpBuilder builder(extOp); - auto tensorType = extOp->getOperand(0).getType().cast(); - auto newType = - RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), - *srcEncoding); -======= OpBuilder builder(extOrBroadcatOp); auto tensorType = extOrBroadcatOp->getOperand(0).getType().cast(); auto newType = RankedTensorType::get( tensorType.getShape(), tensorType.getElementType(), *srcEncoding); ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 auto newConvertOp = builder.create( convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); IRMapping mapping; diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 66c719c23fd2..c8d57d0efd11 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -313,7 +313,7 @@ def make_hash(fn, target, env_vars, **kwargs): "ptx": ptx_arg_type_pattern, } if is_hip(): - ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:' + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' else: ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index cc69a756d8d7..dea0787139bc 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1342,38 +1342,32 @@ def dot(lhs: tl.tensor, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): - if is_hip(): + # Checks for non-cuda archs + if not _is_cuda(target): assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \ (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \ f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!" - return - # Checks for non-cuda archs + # Checks for cuda archs if target.capability < 90: - # Checks for cuda arch - if arch < 90: - assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" - if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): - return - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" + assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. 
Second operand ({rhs_dtype})" else: - assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" - assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" - if lhs_dtype.is_int() or rhs_dtype.is_int(): - assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" - assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" - elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): - assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" - assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})" - else: - assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" - assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - return - - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - return + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" assert lhs.type.is_block() and rhs.type.is_block() assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target) @@ -1396,7 +1390,6 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): if not supported_fp8_dot and rhs_fp8: rhs = cast(rhs, tl.float16, builder) - if lhs.type.scalar.is_int(): assert lhs.type.scalar == tl.int8, "only int8 supported!" # TODO: This is CUDA specific, check if ROCm has the same limitation @@ -1417,6 +1410,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): # Cast operands of types f16 and i8 for configurations where FMA only supported. 
if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty): + # max_num_imprecise_acc does not yet apply to hip + if is_hip(): + max_num_imprecise_acc = 0 + if max_num_imprecise_acc is None: + max_num_imprecise_acc = 2**30 ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty lhs = cast(lhs, ret_cast_scalar_ty, builder) rhs = cast(rhs, ret_cast_scalar_ty, builder) @@ -1425,10 +1423,16 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): else: _0 = builder.create_splat(builder.get_fp32(0), [M, N]) ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) - ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc), ret_ty) return cast(ret, ret_scalar_ty, builder) - if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: + if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth <= 32: + # max_num_imprecise_acc does not yet apply to hip + if is_hip(): + max_num_imprecise_acc = 0 + if max_num_imprecise_acc is None: + max_num_imprecise_acc = 2**30 + if lhs.type.scalar.is_int(): ret_dot_scalar_ty = tl.int32 _0 = builder.create_splat(builder.get_int32(0), [M, N]) @@ -1436,7 +1440,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): ret_dot_scalar_ty = tl.float32 _0 = builder.create_splat(builder.get_fp32(0), [M, N]) ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N]) - ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc), ret_ty) return cast(ret, ret_scalar_ty, builder) From 8bc417b9b71659503bb088733c6a6be48d35ec83 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 2 Nov 2023 14:44:41 +0000 Subject: [PATCH 119/122] do not emit nvidia inline asm --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 4 ++++ test/Conversion/tritongpu_to_llvm.mlir | 10 ++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index f0ced79ce0e9..0529320e9cc1 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -258,6 +258,9 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value pred) { +#if USE_ROCM + return load(ptr); +#else MLIRContext *ctx = rewriter.getContext(); auto ptrTy = ptr.getType().cast(); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); @@ -272,6 +275,7 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, auto &ld = builder.create<>("ld")->shared().b(bitwidth); ld(dOpr, ptrOpr).predicate(pred, "b"); return builder.launch(rewriter, loc, elemTy); +#endif } static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index d1574c24f4f7..c265f2f442b5 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2013,10 +2013,12 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c // ----- // CHECK-LABEL: copyitem -// CHECK: st.shared.b8 -// CHECK: ld.shared.b8 -// CHECK-NOT: 
st.shared.b1 -// CHECK-NOT: ld.shared.b1 +// GCN: llvm.store +// GCN: llvm.load +// PTX: st.shared.b8 +// PTX: ld.shared.b8 +// PTX-NOT: st.shared.b1 +// PTX-NOT: ld.shared.b1 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @copyitem() attributes {noinline = false} { From aefc94bd25ecdace45d99bee2a31aa732fba5424 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 2 Nov 2023 17:40:37 +0000 Subject: [PATCH 120/122] ROCM IFU: fix test_dot_mfma_vector_load test fix for previous commit --- python/test/unit/language/test_core_amd.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index ce4bfcf2f596..6264ec62fef4 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -2718,11 +2718,6 @@ def __str__(self): return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" -def get_gpu_name(): - arch = triton.compiler.compiler.get_architecture_descriptor(None) - return arch["gfx_arch"] - - @pytest.mark.parametrize("vec_size", [2, 4]) @pytest.mark.parametrize("swizzle", [True, False]) @pytest.mark.parametrize("transposeA", [True, False]) @@ -2795,7 +2790,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): %21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked> %22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2> %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> - %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> -> tensor<32x32xf32, #mfma> + %24 = tt.dot %20, %23, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> -> tensor<32x32xf32, #mfma> %25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #mfma>) -> tensor<32x32xf32, #blocked> %26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked> @@ -2817,11 +2812,8 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: f.write(ir) f.flush() - arch_triple = "amdgcn-amd-amdhsa" - arch_name = get_gpu_name() - features = "" - warp_size = 64 - capabilities = [arch_triple, arch_name, features, warp_size] + backend = triton.common.backend.get_backend("hip") + capabilities = backend.get_architecture_descriptor() kernel = triton.compile(f.name, device_type="hip", cc=capabilities) import triton.language.semantic as sem From 502525ff11cdd84adde0ec2428032363fff8a309 Mon Sep 17 00:00:00 2001 From: 
oplavsic <130548569+oplavsic@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:08:26 +0100 Subject: [PATCH 121/122] ROCM IFU: Fix ScanOp test by implementing idx ShflKind in commonShflSync (#391) * Fix ScanOp test * Remove commented-out code --------- Co-authored-by: Ognjen Fix typo in commonShflSync --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 49 +++++++++++----------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 0529320e9cc1..7d08533a459f 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -282,14 +282,6 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i, int strideInt, NVVM::ShflKind mode, Value clamp, Value laneId = Value()) { unsigned bits = val.getType().getIntOrFloatBitWidth(); - //int stride = i.cast(); - //int stride = i.dyn_cast(); - //int stride = i.cast().getValue().getSExtValue(); - //int stride = i.Value(); - //constantOp.getValue().cast().getValue().getSExtValue(); - //unsigned strideint = i.cast().getValue().getSExtValue(); - //auto intAttr = i.dyn_cast_or_null(); - //auto strideint = intAttr.getValue().getSExtValue(); #ifdef USE_ROCM //On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32bit/dwords @@ -316,6 +308,21 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, #ifdef USE_ROCM GCNBuilder builder; + + auto permute = [&](Value lane, StringRef permuteInstStr) { + assert(permuteInstStr == "ds_permute_b32" || + permuteInstStr == "ds_bpermute_b32"); + // Multiple lineId by 4. (More on permute instruction semantics: + // https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180 + Value byteOffset = i32_val(2); + Value permuteAddr = shl(lane, byteOffset); + auto shfl = builder.create(permuteInstStr.str()); + auto dOpr = builder.newOperand("=v"); + auto addrOpr = builder.newOperand(permuteAddr, "v"); + auto aOpr = builder.newOperand(val, "v"); + (*shfl)(dOpr, addrOpr, aOpr); + }; + switch (mode) { case NVVM::ShflKind::bfly: if (strideInt > 16) { @@ -327,14 +334,8 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)}) .getResult(0); Value stride = i32_val(32); - Value byteOffset = i32_val(2); Value lineId = add(threadId, stride); - Value permuteAddr = shl(lineId, byteOffset); - auto shfl = builder.create("ds_permute_b32"); - auto dOpr = builder.newOperand("=v"); - auto addrOpr = builder.newOperand(permuteAddr, "v"); - auto aOpr = builder.newOperand(val, "v"); - (*shfl)(dOpr, addrOpr, aOpr); + permute(lineId, "ds_permute_b32"); } else { // This map facilates the butterfly shuffle pattern for a stride less // than 16. The pattern stride is the key of the map. 
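The permute helper introduced in this hunk captures two pieces of lane arithmetic: ds_permute_b32 / ds_bpermute_b32 select lanes by byte address (hence the shift left by 2, i.e. lane * 4), and a butterfly shuffle of stride s pairs a lane with lane ^ s. A small host-side model of that indexing (not from the patch; purely illustrative):

    def bfly_partner(lane_id: int, stride: int) -> int:
        # butterfly (shfl.bfly) semantics: the partner lane differs in the bits set in stride
        return lane_id ^ stride

    def bpermute_byte_addr(src_lane: int) -> int:
        # ds_bpermute_b32 addresses lanes in bytes, so the lane index is scaled by 4
        return src_lane << 2

    assert bfly_partner(5, 16) == 21
    assert bpermute_byte_addr(bfly_partner(5, 16)) == 84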
@@ -348,18 +349,18 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, (*shfl)(dOpr, aOpr, maskOpr); } break; - case NVVM::ShflKind::up: - //assert(shuffleType == "up" && "Only shfl_bfly and shfl_up are supported"); + case NVVM::ShflKind::up: { Value mask = icmp_slt(laneId, i); Value delta = sub(laneId, i); Value index = select(mask, laneId, delta); - Value byteOffset = i32_val(2); - Value permuteAddr = shl(index, byteOffset); - auto shfl = builder.create("ds_bpermute_b32"); - auto dOpr = builder.newOperand("=v"); - auto addrOpr = builder.newOperand(permuteAddr, "v"); - auto aOpr = builder.newOperand(val, "v"); - (*shfl)(dOpr, addrOpr, aOpr); + permute(index, "ds_bpermute_b32"); + break; + } + case NVVM::ShflKind::idx: + permute(i, "ds_bpermute_b32"); + break; + default: + assert(false && "Unsupported ShflKind"); break; } From 85216ea5c5b62d15c0fa2d9c5fb78e5566cca1a9 Mon Sep 17 00:00:00 2001 From: Jason Furmanek Date: Tue, 7 Nov 2023 03:26:02 +0000 Subject: [PATCH 122/122] ROCM IFU: Resolve conflicts in FA tutorial --- python/tutorials/06-fused-attention.py | 117 ------------------------- 1 file changed, 117 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 8f8ca18c48f5..c9939c7ebc6f 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -701,12 +701,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} -<<<<<<< HEAD -<<<<<<< HEAD o = torch.empty_like(q, dtype=v.dtype) -======= - o = torch.empty_like(q) ->>>>>>> 0eed50883... resolve some merge conflicts if torch.version.hip is None: BLOCK_M = 128 BLOCK_N = 64 if Lk <= 64 else 32 @@ -721,18 +716,6 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): ) M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) -<<<<<<< HEAD -======= - o = torch.empty_like(q) - BLOCK_M = 128 - BLOCK_N = 64 if Lk <= 64 else 32 - num_stages = 4 if Lk <= 64 else 3 - num_warps = 4 - grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 -======= ->>>>>>> 0eed50883... resolve some merge conflicts _attn_fwd[grid]( q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -749,20 +732,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk) block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) -<<<<<<< HEAD -======= - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=Lk, - STAGE=3, - num_warps=num_warps, - num_stages=num_stages, - ) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 - -======= ->>>>>>> 0eed50883...
resolve some merge conflicts ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale @@ -862,39 +832,17 @@ def backward(ctx, do): @pytest.mark.parametrize('causal', [False, True]) def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): torch.manual_seed(20) -<<<<<<< HEAD q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() if TORCH_HAS_FP8E5: q = q.to(torch_dtype) k = k.to(torch_dtype) -======= - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) ->>>>>>> 0eed50883... resolve some merge conflicts sm_scale = 0.5 dout = torch.randn_like(q, dtype=torch.float16) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) -<<<<<<< HEAD p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale -======= - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale ->>>>>>> 0eed50883... resolve some merge conflicts if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() @@ -918,27 +866,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() -======= -@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 2, 1024, 64)]) -@pytest.mark.parametrize("causal", [True]) -def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - torch.manual_seed(20) - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 sm_scale = 0.5 split_kernel = True dout = torch.randn_like(q) @@ -984,7 +911,6 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): for D_HEAD in [128, 64]: if mode == 'bwd' and D_HEAD == 128: continue -<<<<<<< HEAD for causal in [False, True]: if mode == 'bwd' and causal == False: continue @@ -1013,49 +939,6 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): 'mode': mode, 'causal': causal}) ) -======= -configs = [ - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[2**i for i in range(10, 15)], - line_arg="provider", - line_vals=["triton"] + (["flash"] if HAS_FLASH else []), - line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-")], - ylabel="ms", - plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}", - args={ - "H": N_HEADS, - "BATCH": BATCH, - "D_HEAD": D_HEAD, - "dtype": torch.float16, - "mode": mode, - "causal": causal, - }, - ) - for mode in ["fwd", "bwd"] - for causal in [True] -] ->>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 
-======= - configs.append(triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 15)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}', - args={ - 'H': N_HEADS, - 'BATCH': BATCH, - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode, - 'causal': causal}) - ) ->>>>>>> 0eed50883... resolve some merge conflicts @triton.testing.perf_report(configs)
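For reference, the perf_report decorator above turns bench_flash_attention into a report object; the benchmark is normally driven with a call along these lines (a sketch, assuming the driver at the bottom of 06-fused-attention.py is unchanged by this patch series):

    # print the benchmark table and save plots/CSV under the given path
    bench_flash_attention.run(save_path='.', print_data=True)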