From 40122b145fe8ad1c85596e03f483f6495afe059b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 9 Jan 2024 06:42:25 -0800 Subject: [PATCH] Pallas now uses MLIR Python builders to lower to Triton IR This allows us to drop a dependency on the Triton Python package in the future, and delegate ->ptx compilation to XLA. PiperOrigin-RevId: 596916998 --- jax_triton/triton_lib.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 3617eca1..8ef12e4a 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -20,6 +20,7 @@ import copy import functools import os +import tempfile import types from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union import zlib @@ -46,6 +47,7 @@ import triton.language as tl from triton.runtime import autotuner import triton._C.libtriton as _triton + from triton._C.libtriton import ir as tl_ir from triton.common.backend import get_backend import triton.compiler.backends.cuda as cb @@ -192,6 +194,17 @@ def compile_ttir_to_ptx_inplace( compute_capability = triton_kernel_call_lib.get_compute_capability(device) if cuda_options.debug: print(ttir) + if isinstance(ttir, ir.Module): + # Triton compilation APIs only accept Triton-specific MLIR wrappers. + # So, here we serialize an ir.Module to a file and then deserialize + # it as a tl_ir.module. + tl_context = tl_ir.context() + tl_context.load_triton() + with tempfile.NamedTemporaryFile(mode="wb") as f: + ttir.operation.write_bytecode(f) + f.flush() + ttir = tl_ir.parse_mlir_module(f.name, tl_context) + ttir.context = tl_context try: metadata = dict() opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)