Skip to content

Commit

Permalink
Pallas now uses MLIR Python builders to lower to Triton IR
Browse files Browse the repository at this point in the history
This allows us to drop a dependency on the Triton Python package in the future,
and delegate ->ptx compilation to XLA.

PiperOrigin-RevId: 596916998
  • Loading branch information
superbobry authored and The jax_triton Authors committed Jan 11, 2024
1 parent 4a5791d commit 40122b1
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import copy
import functools
import os
import tempfile
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
import zlib
Expand All @@ -46,6 +47,7 @@
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton as _triton
from triton._C.libtriton import ir as tl_ir
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb

Expand Down Expand Up @@ -192,6 +194,17 @@ def compile_ttir_to_ptx_inplace(
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if cuda_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
# Triton compilation APIs only accept Triton-specific MLIR wrappers.
# So, here we serialize an ir.Module to a file and then deserialize
# it as a tl_ir.module.
tl_context = tl_ir.context()
tl_context.load_triton()
with tempfile.NamedTemporaryFile(mode="wb") as f:
ttir.operation.write_bytecode(f)
f.flush()
ttir = tl_ir.parse_mlir_module(f.name, tl_context)
ttir.context = tl_context
try:
metadata = dict()
opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
Expand Down

0 comments on commit 40122b1

Please sign in to comment.