Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
chr1sj0nes authored and The jax_triton Authors committed Jun 22, 2023
1 parent e97cf5d commit f947255
Showing 3 changed files with 119 additions and 117 deletions.
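For context, the user-visible part of this change is a new `serialized_metadata` keyword on `jax_triton.triton_call`: its bytes are folded into the serialized kernel call that backs the lowered custom call. A minimal sketch of how it might be used, assuming the usual element-wise add kernel from the jax-triton examples (the kernel body and the metadata payload are illustrative; only the `serialized_metadata` argument comes from this commit):

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # Each program instance handles one contiguous block of elements.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  tl.store(output_ptr + offsets,
           tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  block_size = 8
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
      grid=(x.size // block_size,),
      block_size=block_size,
      # New in this commit: arbitrary bytes embedded in the serialized
      # kernel-call proto (this payload is made up for illustration).
      serialized_metadata=b"example-metadata",
  )


x = y = jnp.arange(8, dtype=jnp.float32)
print(jax.jit(add)(x, y))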
64 changes: 38 additions & 26 deletions jax_triton/pallas/triton_lowering.py
@@ -35,9 +35,9 @@
from jax._src.state import AbstractRef
from jax._src.state import discharge
from jax._src.state import primitives as sp
from jax._src.util import split_list
from jax._src.util import merge_lists
from jax._src.util import partition_list
from jax._src.util import split_list
from jax._src.util import weakref_lru_cache
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
@@ -47,7 +47,7 @@
from jax_triton.pallas import core as pallas_core
from jax_triton.pallas import pallas_call_p
from jax_triton.pallas import primitives
from jax_triton.triton_lib import compile_ttir_inplace
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
from jax_triton.triton_lib import get_triton_type
import numpy as np
from triton._C.libtriton.triton import ir as tl_ir
@@ -106,9 +106,11 @@ class TritonLoweringResult:

@dataclasses.dataclass
class TritonCompilationResult:
cubin: bytes
name: str
shared_mem: int
kernel_name: str
ttir: str
ptx: str
shared_mem_bytes: int
compute_capability: int
lowering_result: TritonLoweringResult


@@ -1571,15 +1573,17 @@ def compile_jaxpr(
jaxpr, in_shapes, grid_mapping, name
)
device = 0
ttir = lowering_result.module
cubin, name, shared_mem = compile_ttir_inplace(
ttir,
ttir = str(lowering_result.module)
ptx, name, shared_mem_bytes, compute_capability = compile_ttir_to_ptx_inplace(
lowering_result.module,
device=device,
num_warps=num_warps,
num_stages=num_stages,
dump=debug,
)
return TritonCompilationResult(cubin, name, shared_mem, lowering_result)
return TritonCompilationResult(
name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
)


def pallas_call_lowering(
@@ -1625,38 +1629,45 @@ def pallas_call_lowering(
num_stages,
debug=debug,
)
cubin = compilation_result.cubin
name = compilation_result.name
shared_mem = compilation_result.shared_mem
lowering_result = compilation_result.lowering_result

if debug:
lowering_result.module.dump()
out_type = ir.TupleType.get_tuple(
[
ir.RankedTensorType.get(
out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype)
)
for out_shape in ctx.avals_out
]
)
compilation_result.lowering_result.module.dump()

kernel = triton_kernel_call_lib.TritonKernel(
cubin, name, num_warps, shared_mem
compilation_result.kernel_name,
num_warps,
compilation_result.shared_mem_bytes,
compilation_result.ptx,
compilation_result.ttir,
compilation_result.compute_capability,
)

grid = triton_utils.normalize_grid(
compilation_result.lowering_result.grid, metaparams={}
)

kernel_params = []
for _ in range(len(in_shapes) + len(out_shapes)):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
0, # bytes to zero # TODO(cjfj): Expose through user API.
True, # divisible by 16
16, # divisible by 16
)
)

kernel_call = triton_kernel_call_lib.TritonKernelCall(
kernel, grid[0], grid[1], grid[2], kernel_params
)
ctx.module_context.add_keepalive(kernel_call)

out_type = ir.TupleType.get_tuple(
[
ir.RankedTensorType.get(
out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype)
)
for out_shape in ctx.avals_out
]
)

output_operand_aliases = ir.ArrayAttr.get(
[
mhlo.OutputOperandAlias.get(
@@ -1667,12 +1678,13 @@
for input, output in input_output_aliases
]
)

out = mhlo.CustomCallOp(
[out_type],
in_nodes,
call_target_name=ir.StringAttr.get("triton_kernel_call"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(kernel_call.descriptor),
backend_config=ir.StringAttr.get(kernel_call.to_proto(b"")),
api_version=mlir.i32_attr(2), # API_VERSION_STATUS_RETURNING
called_computations=ir.ArrayAttr.get([]),
operand_layouts=triton_utils.avals_to_layouts(ctx.avals_in),
162 changes: 76 additions & 86 deletions jax_triton/triton_lib.py
@@ -17,7 +17,6 @@
import os
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
import weakref

from absl import logging
import jax
@@ -126,22 +125,17 @@ def aval_size_bytes(aval):
return np.dtype(aval.dtype).itemsize * aval.size


# Compiled kernels are kept alive by the kernel call which, in turn, are kept
# alive by the jitted JAX function.
_COMPILED_KERNEL_CACHE = weakref.WeakValueDictionary()

def ptx_get_kernel_name(module) -> str:
return tc.get_kernel_name(module, pattern='// .globl')


def compile_ttir_inplace(
def compile_ttir_to_ptx_inplace(
ttir,
device: int = 0,
num_warps: int = 4,
num_stages: Optional[int] = None,
dump: bool = False,
) -> Tuple[bytes, str, int]:
"""Compiles a TTIR module to CUBIN (the TTIR is modified in-place)."""
) -> Tuple[str, str, int, int]:
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_stages is None:
num_stages = 3 if compute_capability >= 75 else 2
@@ -162,15 +156,17 @@ def compile_ttir_inplace(
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem = _triton.get_shared_memory_size(ttgir)
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
print(llir)
ptx = tc.llir_to_ptx(llir, compute_capability)
if dump:
print(ptx)
name = ptx_get_kernel_name(ptx)
cubin = tc.ptx_to_cubin(ptx, compute_capability)
return cubin, name, shared_mem
return ptx, name, shared_mem_bytes, compute_capability


_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?


def get_or_create_triton_kernel(
@@ -215,26 +211,29 @@ def get_or_create_triton_kernel(
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
ttir = code_gen.ast_to_ttir(fn, signature, specialization, constants,
debug=dump)
cubin, name, shared_mem = compile_ttir_inplace(
ttir,
device=device,
num_warps=num_warps,
num_stages=num_stages,
dump=dump,
module = code_gen.ast_to_ttir(
fn, signature, specialization, constants, debug=dump
)
ttir = str(module)  # `module` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
compile_ttir_to_ptx_inplace(
module,
device=device,
num_warps=num_warps,
num_stages=num_stages,
dump=dump,
)
)

kernel = triton_kernel_call_lib.TritonKernel(
cubin, name, num_warps, shared_mem
kernel_name, num_warps, shared_mem_bytes, ptx, ttir, compute_capability
)

_COMPILED_KERNEL_CACHE[cache_key] = kernel

return kernel, specialization


_KERNEL_CALL_CACHE = weakref.WeakValueDictionary()


def triton_kernel_call_lowering(
ctx,
*array_args,
@@ -248,6 +247,7 @@ def triton_kernel_call_lowering(
input_output_aliases,
zeroed_outputs,
debug,
serialized_metadata,
**metaparams,
):
if jaxlib.version.__version_info__ < (0, 3, 22) and input_output_aliases:
@@ -332,71 +332,56 @@ def prune_configs(configs, named_args):
)
)

# Cache auto-tuned calls with the same parameters, so the auto-tuning need
# only be performed once.
cache_key = (
fn,
tuple(arg_dtypes),
tuple(scalar_args),
tuple(tuple(p.items()) for p in config_params),
)
kernel_call = _KERNEL_CALL_CACHE.get(cache_key)

if kernel_call is None:
kernel_calls = []
for params in config_params:
kernel, specialization = get_or_create_triton_kernel(
fn,
arg_dtypes,
scalar_args,
num_warps=params["num_warps"],
num_stages=params["num_stages"],
metaparams=dict(params["metaparams"]),
dump=debug,
)

kernel_params = []
zeroed_params_with_sizes = dict(params["zeroed_params_with_sizes"])
for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)):
if isinstance(arg, core.ShapedArray):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
i in specialization.divisible_by_16,
)
)
elif i not in specialization.equal_to_1:
kernel_params.append(
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
)

kernel_calls.append(
triton_kernel_call_lib.TritonKernelCall(
kernel,
params["grid"][0],
params["grid"][1],
params["grid"][2],
kernel_params,
)
)
kernel_calls = []
for params in config_params:
kernel, specialization = get_or_create_triton_kernel(
fn,
arg_dtypes,
scalar_args,
num_warps=params["num_warps"],
num_stages=params["num_stages"],
metaparams=dict(params["metaparams"]),
dump=debug,
)

if len(kernel_calls) > 1:
named_scalar_args = {fn.arg_names[i]: v for i, _, v in scalar_args}
input_output_aliases_with_sizes = tuple(
(input_idx, output_idx, aval_size_bytes(ctx.avals_in[input_idx]))
for input_idx, output_idx in input_output_aliases
)
kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall(
f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}",
[(call, str(config)) for call, config in zip(kernel_calls, configs)],
input_output_aliases_with_sizes,
)
else:
kernel_call = kernel_calls[0]
kernel_params = []
zeroed_params_with_sizes = dict(params["zeroed_params_with_sizes"])
for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)):
if isinstance(arg, core.ShapedArray):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
16 if (i in specialization.divisible_by_16) else 0,
)
)
elif i not in specialization.equal_to_1:
kernel_params.append(
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
)

_KERNEL_CALL_CACHE[cache_key] = kernel_call
kernel_calls.append(
triton_kernel_call_lib.TritonKernelCall(
kernel,
params["grid"][0],
params["grid"][1],
params["grid"][2],
kernel_params,
)
)

ctx.module_context.add_keepalive(kernel_call)
if len(kernel_calls) > 1:
named_scalar_args = {fn.arg_names[i]: v for i, _, v in scalar_args}
input_output_aliases_with_sizes = tuple(
(input_idx, output_idx, aval_size_bytes(ctx.avals_in[input_idx]))
for input_idx, output_idx in input_output_aliases
)
kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall(
f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}",
[(call, str(config)) for call, config in zip(kernel_calls, configs)],
input_output_aliases_with_sizes,
)
else:
kernel_call = kernel_calls[0]

out_types = [
ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
@@ -416,12 +401,13 @@ def prune_configs(configs, named_args):
)
)

serialized = kernel_call.to_proto(serialized_metadata)
return mhlo.CustomCallOp(
out_types,
array_args,
call_target_name=ir.StringAttr.get(call_name),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(kernel_call.descriptor),
backend_config=ir.StringAttr.get(serialized),
api_version=mlir.i32_attr(2), # API_VERSION_STATUS_RETURNING
called_computations=ir.ArrayAttr.get([]),
operand_layouts=utils.avals_to_layouts(ctx.avals_in),
@@ -457,6 +443,7 @@ def triton_call(
Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
] = (),
debug: bool = False,
serialized_metadata: bytes = b"",
**metaparams: Any,
) -> Any:
"""Calls a Triton kernel with `jax.Array` arguments.
Expand Down Expand Up @@ -531,6 +518,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
num_warps: The number of warps used to execute the Triton kernel.
num_stages: The number of stages emitted by the Triton compiler.
debug: Prints out intermediate IRs if True for debugging purposes.
serialized_metadata: Arbitrary metadata that will be added into the
serialized kernel call.
**metaparams: Additional keyword arguments that will be provided to a `grid`
(if it is a function) and to the Triton kernel as `constexpr` arguments.
@@ -573,6 +562,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
input_output_aliases=tuple(input_output_aliases.items()),
zeroed_outputs=zeroed_outputs,
debug=debug,
serialized_metadata=serialized_metadata,
**metaparams,
)
return tree_util.tree_unflatten(out_tree, out_flat)
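Taken together, the two files change the lowering in the same way: the kernel is now described by its kernel name, PTX, TTIR and compute capability (instead of a pre-built CUBIN), and the custom call's backend_config carries a protobuf serialization of the kernel call rather than the old opaque descriptor. A condensed, non-runnable sketch of the new flow, reusing the names from the diff (grid and kernel_params stand in for the values built in the surrounding code):

# The kernel keeps PTX/TTIR plus the compute capability, so it can be
# serialized and rebuilt later, instead of shipping a baked CUBIN.
kernel = triton_kernel_call_lib.TritonKernel(
    kernel_name, num_warps, shared_mem_bytes, ptx, ttir, compute_capability)

kernel_call = triton_kernel_call_lib.TritonKernelCall(
    kernel, grid[0], grid[1], grid[2], kernel_params)

# The serialized proto (with any user-supplied metadata bytes folded in)
# becomes the custom call's backend_config.
backend_config = ir.StringAttr.get(kernel_call.to_proto(serialized_metadata))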
