From 038cbc5641c4dee3835879bed86ce636d930e1dc Mon Sep 17 00:00:00 2001 From: Samuel Ginzburg Date: Tue, 5 Nov 2024 09:02:26 -0500 Subject: [PATCH] [frontend] Remove Complex Regex for MLIR Parsing (#4924) There were a number of complex regexes used for parsing MLIR in the python frontend. For maintainability reasons, it is likely better to just expose the MLIR bindings to python and use those instead. The PTX regex is left in place because we don't have an easy way to parse PTX (for now). --- python/src/ir.cc | 48 +++++++++++++ python/test/unit/tools/test_irsource.py | 93 +++++++++++++++++++++++++ python/triton/compiler/__init__.py | 7 +- python/triton/compiler/compiler.py | 65 +++++++++-------- 4 files changed, 181 insertions(+), 32 deletions(-) create mode 100644 python/test/unit/tools/test_irsource.py diff --git a/python/src/ir.cc b/python/src/ir.cc index cce7c87e8d87..a2a2c7263c69 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -24,6 +24,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/LocationSnapshot.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" @@ -491,6 +492,16 @@ void init_triton_ir(py::module &&m) { [](ModuleOp &self, FuncOp &funcOp) -> void { self.push_back(funcOp); }) + .def("get_entry_func_name", + [](ModuleOp &self) -> std::string { + for (auto &op : self.getOps()) { + if (auto func = dyn_cast(op)) { + if (LLVM::isKernel(func)) + return func.getName().str(); + } + } + return ""; + }) .def("has_function", [](ModuleOp &self, std::string &funcName) -> bool { if (self.lookupSymbol(funcName)) @@ -501,6 +512,43 @@ void init_triton_ir(py::module &&m) { [](ModuleOp &self, std::string &funcName) -> FuncOp { return self.lookupSymbol(funcName); }) + /* + * def ty_to_cpp(ty) is the consumer of this function. + * If the type is a ptr it expects ty[0] == '*', else the type itself. + */ + + .def("get_function_signature", + [](ModuleOp &self, FuncOp &func) -> std::vector { + std::vector strVec; + + auto type = func.getFunctionType(); + unsigned numArgs = type.getNumInputs(); + for (unsigned i = 0; i != numArgs; ++i) { + std::string tempType; + llvm::raw_string_ostream os(tempType); + + auto ty = type.getInput(i); + if (auto attributes = func.getCallableArgAttrs()) { + Attribute attr = attributes[i]; + // Check for tt.nv_tma_desc = 1 + if (auto dAttr = dyn_cast(attr)) { + if (dAttr.contains("tt.nv_tma_desc")) { + strVec.push_back("nvTmaDesc"); + continue; + } + } + } + if (auto ptrType = dyn_cast(ty)) { + auto pType = ptrType.getPointeeType(); + os << "*"; + pType.print(os); + } else { + ty.print(os); + } + strVec.push_back(tempType); + } + return strVec; + }) .def("get_int_attr", [](ModuleOp &self, std::string name) -> py::object { auto ret = self->getAttrOfType(name); diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py new file mode 100644 index 000000000000..a886ebb457f4 --- /dev/null +++ b/python/test/unit/tools/test_irsource.py @@ -0,0 +1,93 @@ +import tempfile +import triton +from triton.compiler import IRSource +from triton._C.libtriton import ir + +target = triton.runtime.driver.active.get_current_target() + + +def test_mlir_attribute_parsing() -> None: + ''' + Tests that MLIR attributes are parsed correctly from input ttir/ttgir. + + Checks for the following: + 1. Name and type signature are parsed correctly + 2. _get_num_warps_from_ir_str() works + 3. tt.nv_tma_desc attribute is parsed correctly + ''' + + sample_ttgir = r""" +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32}, + %desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} { + tt.return + } +} +""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(sample_ttgir) + f.flush() + context = ir.context() + src = IRSource(f.name, context) + + # check name and type signature + # should match ty_to_cpp(...) + assert src.signature == \ + {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \ + 4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"} + assert src.name == "@matmul_kernel" + + # check num warps + assert src.parse_options()['num_warps'] == 8 + + sample_ttgir_vector_add = r""" + #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) + attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return + } + } + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(sample_ttgir_vector_add) + f.flush() + context = ir.context() + src = IRSource(f.name, context) + + # now test compilation + triton.compile(f.name, target=target) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index bbe8c047c6d1..a05efd7e0807 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,7 @@ -from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict +from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict from .errors import CompilationError -__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] +__all__ = [ + "compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", + "LazyDict" +] diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 304b40697458..18ffb85c9ff1 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -25,19 +25,13 @@ # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 -mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { - "ttir": mlir_prototype_pattern, - "ttgir": mlir_prototype_pattern, "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { - "ttir": mlir_arg_type_pattern, - "ttgir": mlir_arg_type_pattern, "ptx": ptx_arg_type_pattern, } @@ -55,16 +49,6 @@ def convert_type_repr(x): return x -def _get_num_warps_from_ir_str(src: str): - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' - # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if - # e.g. someone has an instruction (not module) attribute named "num-warps". - num_warps_matches = re.findall(ttgir_num_warps_pattern, src) - assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - num_warps = int(num_warps_matches[0]) - return num_warps - - class ASTSource: def __init__(self, fn, signature, constants=None, attrs=None) -> None: @@ -107,28 +91,41 @@ def parse_options(self): class IRSource: - def __init__(self, path): + def __init__(self, path, context): self.path = path path = Path(path) self.ext = path.suffix[1:] self.src = path.read_text() - match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) - self.name = match.group(1) - signature = match.group(2) - types = re.findall(arg_type_pattern[self.ext], signature) - self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + ir.load_dialects(context) + + # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. + # TODO - replace with a proper parser + if self.ext == "ptx": + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + else: + self.module = ir.parse_mlir_module(self.path, context) + fn_name = self.module.get_entry_func_name() + self.name = "@" + fn_name + funcOp = self.module.get_function(fn_name) + func_ty = self.module.get_function_signature(funcOp) + self.signature = {k: ty for k, ty in enumerate(func_ty)} def hash(self): return hashlib.sha256(self.src.encode("utf-8")).hexdigest() def make_ir(self, options, codegen_fns, module_map, context): - module = ir.parse_mlir_module(self.path, context) - module.context = context - return module + self.module.context = context + return self.module def parse_options(self): if self.ext == "ttgir": - return {'num_warps': _get_num_warps_from_ir_str(self.src)} + num_warps = self.module.get_int_attr("triton_gpu.num-warps") + assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute" + return {'num_warps': num_warps} return dict() @@ -225,7 +222,9 @@ def compile(src, target=None, options=None): # create backend if ir_source: assert isinstance(src, str), "source must be either AST or a filepath" - src = IRSource(src) + context = ir.context() + src = IRSource(src, context) + extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager @@ -266,9 +265,15 @@ def compile(src, target=None, options=None): # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. if ir_source: first_stage += 1 - context = ir.context() - ir.load_dialects(context) - backend.load_dialects(context) + + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + else: + # For IRSource, we have already grabbed the context + called ir.load_dialects + # just need to load the dialects for the backend. + backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() try: