From 0c2899d6df0717119cdd66061ba701a03527e9a0 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 21 Mar 2023 16:15:24 +0000 Subject: [PATCH 1/6] Initial attempt --- jax_triton/pallas/core.py | 129 +++++++++++++++++++++++++++++++ jax_triton/pallas/pallas_call.py | 126 ++++++------------------------ tests/pallas_test.py | 2 +- 3 files changed, 152 insertions(+), 105 deletions(-) diff --git a/jax_triton/pallas/core.py b/jax_triton/pallas/core.py index cd165c5c..ae492e92 100644 --- a/jax_triton/pallas/core.py +++ b/jax_triton/pallas/core.py @@ -15,10 +15,27 @@ """Module for pallas-core functionality.""" import contextlib import dataclasses +import functools +from functools import partial from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +import jax.numpy as jnp +from jax._src import api_util from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src import state +from jax._src import tree_util +from jax._src.lax.control_flow import for_loop +from jax.interpreters import partial_eval as pe +from jax._src.util import weakref_lru_cache, safe_map, safe_zip +from jax._src.state.types import AbstractRef + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Grid = tuple[int, ...] +GridOrLambda = Union[Callable[..., Grid], Grid] @dataclasses.dataclass class GridEnv: @@ -73,3 +90,115 @@ class GridSpec: mapped_dims: Tuple[int, ...] replace = dataclasses.replace + +@weakref_lru_cache +def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, + primitive_name: Optional[str] = None): + wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(fun), in_tree) + debug_info = pe.debug_info(fun, in_tree, out_tree_thunk, False, + primitive_name or "") + jaxpr, consts = _initial_style_flat_jaxpr(wrapped_fun, in_avals, + debug_info=debug_info) + return jaxpr, consts, out_tree_thunk() + +def _initial_style_flat_jaxpr(fun: lu.WrappedFun, in_avals, + debug_info: Optional[jax_core.DebugInfo] = None + ) -> tuple[jax_core.Jaxpr, list[Any]]: + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals, debug_info) + jaxpr = for_loop._hoist_consts_to_refs(jaxpr) + return jaxpr, consts + +def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid: + if grid is None: + return () + if isinstance(grid, int): + return (grid,) + return grid + +def extract_function_name(f: Callable, name: Optional[str]) -> str: + if name is None: + name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func" + return name + +def convert_block_spec_to_block_mapping( + grid: Grid, block_spec: Optional[BlockSpec]) -> Optional[BlockMapping]: + if block_spec is None: + return None + in_avals = [jax_core.ShapedArray((), jnp.int32) for _ in grid] + block_shape = tuple( + mapped if s is None else s for s in block_spec.block_shape) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(block_spec.compute_index), in_avals) + return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts)) + +def compute_shape_from_block_spec(block_spec: Optional[BlockSpec], + arg_shape: tuple[int, ...] 
+ ) -> tuple[int, ...]: + if block_spec is None: + return arg_shape + return tuple(s for s in block_spec.block_shape if s is not None) + +@dataclasses.dataclass +class SpecializedKernel: + jaxpr: jax_core.Jaxpr + grid_spec: GridSpec + +@dataclasses.dataclass(frozen=True) +class Kernel: + func: lu.WrappedFun + name: Optional[str] + grid: Optional[Grid] + in_specs: Optional[list[Optional[BlockSpec]]] + out_specs: Optional[list[Optional[BlockSpec]]] + + def __post_init__(self): + if self.grid is None: + if self.in_specs is not None: + raise ValueError("Cannot specify `in_specs` with a `None` grid.") + if self.out_specs is not None: + raise ValueError("Cannot specify `out_specs` with a `None` grid.") + + def get_name(self) -> str: + return extract_function_name(self.func, self.name) + + def specialize(self, + in_avals: tuple[AbstractRef, ...], + out_avals: tuple[AbstractRef, ...], + in_tree: tree_util.PyTreeDef + ) -> tuple[SpecializedKernel, ...]: + grid = preprocess_grid(self.grid) + in_specs = self.in_specs + out_specs = self.out_specs + if out_specs is not None and not isinstance(out_specs, (tuple, list)): + out_specs = (out_specs,) + if out_specs is not None and not isinstance(out_specs, tuple): + out_specs = tuple(out_specs) + if in_specs is None: + in_specs = [None] * len(in_avals) + if out_specs is None: + out_specs = [None] * len(out_avals) + if grid == (): + in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) + for arg in in_avals] + out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) + for arg in out_avals] + else: + in_ref_avals = [ + state.shaped_array_ref( + compute_shape_from_block_spec(block_spec, aval.shape), + aval.dtype) + for block_spec, aval in zip(in_specs, in_avals)] + out_ref_avals = [ + state.shaped_array_ref( + compute_shape_from_block_spec(block_spec, aval.shape), + aval.dtype) + for block_spec, aval in zip(out_specs, out_avals)] + in_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid), + in_specs) + out_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid), + out_specs) + grid_spec = GridSpec(grid, (*in_block_mappings, *out_block_mappings), ()) + jaxpr, consts, out_tree = _initial_style_open_jaxpr( + self.func, in_tree, tuple((*in_ref_avals, *out_ref_avals))) + return [SpecializedKernel(jaxpr, grid_spec)], consts, out_tree diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index b5a5e56e..3b9e255b 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -32,14 +32,13 @@ from jax._src.lib.mlir.dialects import mhlo from jax._src import state from jax._src.util import ( - split_list, safe_map, safe_zip, weakref_lru_cache, + split_list, safe_map, safe_zip, tuple_insert, partition_list) -from jax._src.lax.control_flow import for_loop import jax.numpy as jnp import numpy as np -from jax_triton.utils import avals_to_layouts, normalize_grid from jax_triton.pallas import core as pallas_core +from jax_triton.pallas import tracing_utils map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -280,47 +279,6 @@ def _pallas_call_batching_rule(args, dims, *, return out, (0,) * len(out) batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule -@weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): - wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(fun), in_tree) - debug = pe.debug_info(fun, in_tree, out_tree_thunk, False, - primitive_name or "") 
- jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug) - jaxpr = for_loop._hoist_consts_to_refs(jaxpr) - return jaxpr, consts, out_tree_thunk() - -def _preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid: - if grid is None: - return () - if isinstance(grid, int): - return (grid,) - return grid - -def _extract_function_name(f: Callable, name: Optional[str]) -> str: - if name is None: - name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func" - return name - -def _convert_block_spec_to_block_mapping( - grid: Grid, block_spec: Optional[BlockSpec]) -> Optional[BlockMapping]: - if block_spec is None: - return None - in_avals = [jax_core.ShapedArray((), jnp.int32) for _ in grid] - block_shape = tuple( - pallas_core.mapped if s is None else s for s in block_spec.block_shape) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(block_spec.compute_index), in_avals) - return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts)) - -def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec], - arg_shape: Tuple[int, ...] - ) -> Tuple[int, ...]: - if block_spec is None: - return arg_shape - return tuple(s for s in block_spec.block_shape if s is not None) - def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, grid: Optional[Grid] = None, in_specs: Optional[Sequence[Optional[BlockSpec]]] = None, @@ -329,82 +287,42 @@ def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, interpret: bool = False, name: Optional[str] = None, **compiler_params: Any): - if grid is None: - if in_specs is not None: - raise ValueError("Cannot specify `in_specs` with a `None` grid.") - if out_specs is not None: - raise ValueError("Cannot specify `out_specs` with a `None` grid.") - grid = _preprocess_grid(grid) - name = _extract_function_name(f, name) - singleton = False if not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) singleton = True if not isinstance(out_shape, tuple): out_shape = tuple(out_shape) - flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape) - if out_specs is not None and not isinstance(out_specs, (tuple, list)): - out_specs = (out_specs,) - if out_specs is not None and not isinstance(out_specs, tuple): - out_specs = tuple(out_specs) - flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) - for x in flat_out_shapes] + @jax.jit def wrapped(*args): flat_args, in_tree = tree_util.tree_flatten(args) - if grid is None: - flat_in_specs = [None] * len(flat_args) - flat_out_specs = [None] * len(flat_out_shapes) - in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) - for arg in flat_args] - out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) - for arg in flat_out_shapes] - else: - if in_specs is None: - flat_in_specs = [None for arg in flat_args] - else: - flat_in_specs, in_block_tree = tree_util.tree_flatten(tuple(in_specs)) - if in_block_tree != in_tree: - raise ValueError( - "Pytree specs for arguments and `in_specs` must match: " - f"{in_tree} vs. {in_block_tree}") - if out_specs is None: - flat_out_specs = [None for arg in flat_out_shapes] - else: - flat_out_specs, out_block_tree = tree_util.tree_flatten(out_specs) - if out_block_tree != out_tree: - raise ValueError("Pytree specs for `out_shape` and `out_specs` must match: " - f"{out_tree} vs. 
{out_block_tree}") - in_ref_avals = [ - state.shaped_array_ref( - _compute_shape_from_block_spec(block_spec, arg.shape), arg.dtype) - for block_spec, arg in zip(flat_in_specs, flat_args)] - out_ref_avals = [ - state.shaped_array_ref( - _compute_shape_from_block_spec(block_spec, arg.shape), arg.dtype) - for block_spec, arg in zip(flat_out_specs, flat_out_shapes)] - in_block_mappings = map(partial(_convert_block_spec_to_block_mapping, grid), - flat_in_specs) - out_block_mappings = map(partial(_convert_block_spec_to_block_mapping, grid), - flat_out_specs) + flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape) + jaxpr_in_tree = tree_util.tree_structure((*args, *out_shape)) - jaxpr, consts, _ = _initial_style_open_jaxpr( - f, jaxpr_in_tree, tuple((*in_ref_avals, *out_ref_avals)), - primitive_name="pallas_call") - flat_in_specs = it.chain([None] * len(consts), flat_in_specs) - grid_spec = GridSpec(grid, tuple((*in_block_mappings, - *out_block_mappings)), - ()) + if not isinstance(f, pallas_core.Kernel): + kernel = pallas_core.Kernel(f, name, grid, in_specs, out_specs) + else: + kernel = f + flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) + for a in flat_args) + flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype) + for a in flat_out_shapes) + specialized_kernels, consts, out_tree = kernel.specialize( + flat_in_avals, flat_out_avals, jaxpr_in_tree) which_linear = (False,) * len(flat_args) out_flat = pallas_call_p.bind( - *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear, + *consts, *flat_args, + kernels=specialized_kernels, + name=name, + which_linear=which_linear, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in flat_args), - out_shapes=tuple(flat_out_shapes), debug=debug, + out_shapes=tuple(flat_out_shapes), + debug=debug, interpret=interpret, - grid_spec=grid_spec, input_output_aliases=tuple(input_output_aliases.items()), **compiler_params) + print(out_flat) out = tree_util.tree_unflatten(out_tree, out_flat) if singleton: return out[0] diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 1f2a2b76..9b4e32fa 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -33,7 +33,7 @@ import jax.numpy as jnp import jax_triton as jt from jax_triton import pallas as pl -from jax_triton.pallas.pallas_call import _initial_style_open_jaxpr +from jax_triton.pallas.core import _initial_style_open_jaxpr from jax_triton.pallas.ops import attention from jax_triton.pallas.ops import layer_norm try: From 8e1dfe4d565348d269962c3070a94b0002013b47 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 30 Mar 2023 01:11:29 +0000 Subject: [PATCH 2/6] Working! 
--- jax_triton/pallas/__init__.py | 1 + jax_triton/pallas/core.py | 93 ++------ jax_triton/pallas/pallas_call.py | 287 ++++++++++++++++-------- jax_triton/pallas/tracing_utils.py | 43 ++++ jax_triton/pallas/triton_ir_lowering.py | 137 ++++++----- tests/pallas_test.py | 47 +++- 6 files changed, 381 insertions(+), 227 deletions(-) create mode 100644 jax_triton/pallas/tracing_utils.py diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py index 6ef4b147..98a1c415 100644 --- a/jax_triton/pallas/__init__.py +++ b/jax_triton/pallas/__init__.py @@ -14,6 +14,7 @@ """Module for pallas, a jaxpr "dialect" for Triton.""" from jax_triton.pallas.core import BlockSpec +from jax_triton.pallas.core import Config from jax_triton.pallas.pallas_call import pallas_call from jax_triton.pallas.pallas_call import pallas_call_p from jax_triton.pallas.primitives import atomic_add diff --git a/jax_triton/pallas/core.py b/jax_triton/pallas/core.py index ae492e92..c254b356 100644 --- a/jax_triton/pallas/core.py +++ b/jax_triton/pallas/core.py @@ -31,6 +31,8 @@ from jax._src.util import weakref_lru_cache, safe_map, safe_zip from jax._src.state.types import AbstractRef +from jax_triton.pallas import tracing_utils + map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -91,23 +93,18 @@ class GridSpec: replace = dataclasses.replace -@weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): - wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(fun), in_tree) - debug_info = pe.debug_info(fun, in_tree, out_tree_thunk, False, - primitive_name or "") - jaxpr, consts = _initial_style_flat_jaxpr(wrapped_fun, in_avals, - debug_info=debug_info) - return jaxpr, consts, out_tree_thunk() - -def _initial_style_flat_jaxpr(fun: lu.WrappedFun, in_avals, - debug_info: Optional[jax_core.DebugInfo] = None - ) -> tuple[jax_core.Jaxpr, list[Any]]: - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals, debug_info) - jaxpr = for_loop._hoist_consts_to_refs(jaxpr) - return jaxpr, consts +Platform = str + +@dataclasses.dataclass +class Config: + meta: dict[str, Any] + compiler_params: dict[Platform, dict[str, Any]] + + def to_string(self, platform: str) -> str: + compiler_params = self.compiler_params.get(platform, {}) + return "-".join([*(f"{k}_{v}" for k, v in self.meta.items()), + *(f"{k}_{v}" for k, v in compiler_params.items())]) + def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid: if grid is None: @@ -141,64 +138,8 @@ def compute_shape_from_block_spec(block_spec: Optional[BlockSpec], @dataclasses.dataclass class SpecializedKernel: + name: str jaxpr: jax_core.Jaxpr + num_consts: int grid_spec: GridSpec - -@dataclasses.dataclass(frozen=True) -class Kernel: - func: lu.WrappedFun - name: Optional[str] - grid: Optional[Grid] - in_specs: Optional[list[Optional[BlockSpec]]] - out_specs: Optional[list[Optional[BlockSpec]]] - - def __post_init__(self): - if self.grid is None: - if self.in_specs is not None: - raise ValueError("Cannot specify `in_specs` with a `None` grid.") - if self.out_specs is not None: - raise ValueError("Cannot specify `out_specs` with a `None` grid.") - - def get_name(self) -> str: - return extract_function_name(self.func, self.name) - - def specialize(self, - in_avals: tuple[AbstractRef, ...], - out_avals: tuple[AbstractRef, ...], - in_tree: tree_util.PyTreeDef - ) -> tuple[SpecializedKernel, ...]: - grid = preprocess_grid(self.grid) - in_specs = self.in_specs - out_specs = 
self.out_specs - if out_specs is not None and not isinstance(out_specs, (tuple, list)): - out_specs = (out_specs,) - if out_specs is not None and not isinstance(out_specs, tuple): - out_specs = tuple(out_specs) - if in_specs is None: - in_specs = [None] * len(in_avals) - if out_specs is None: - out_specs = [None] * len(out_avals) - if grid == (): - in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) - for arg in in_avals] - out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) - for arg in out_avals] - else: - in_ref_avals = [ - state.shaped_array_ref( - compute_shape_from_block_spec(block_spec, aval.shape), - aval.dtype) - for block_spec, aval in zip(in_specs, in_avals)] - out_ref_avals = [ - state.shaped_array_ref( - compute_shape_from_block_spec(block_spec, aval.shape), - aval.dtype) - for block_spec, aval in zip(out_specs, out_avals)] - in_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid), - in_specs) - out_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid), - out_specs) - grid_spec = GridSpec(grid, (*in_block_mappings, *out_block_mappings), ()) - jaxpr, consts, out_tree = _initial_style_open_jaxpr( - self.func, in_tree, tuple((*in_ref_avals, *out_ref_avals))) - return [SpecializedKernel(jaxpr, grid_spec)], consts, out_tree + compiler_params: dict[str, Any] diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index 3b9e255b..f782c0f9 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -13,6 +13,7 @@ # limitations under the License. """Module for calling pallas functions from JAX.""" +import dataclasses from functools import partial import itertools as it @@ -20,7 +21,6 @@ import jax from jax import api_util -from jax import linear_util as lu from jax import tree_util from jax import lax from jax.interpreters import ad @@ -29,6 +29,7 @@ from jax.interpreters import xla from jax._src import ad_util from jax._src import core as jax_core +from jax._src import linear_util as lu from jax._src.lib.mlir.dialects import mhlo from jax._src import state from jax._src.util import ( @@ -46,7 +47,9 @@ Grid = Tuple[int, ...] BlockSpec = pallas_core.BlockSpec BlockMapping = pallas_core.BlockMapping +Config = pallas_core.Config GridSpec = pallas_core.GridSpec +SpecializedKernel = pallas_core.SpecializedKernel pallas_call_p = jax_core.Primitive('pallas_call') pallas_call_p.multiple_results = True @@ -75,16 +78,17 @@ def _maybe_dynamic_update_slice(start_idx, block_shape, value, update, assert update.shape == block_shape return lax.dynamic_update_slice(value, update, start_idx) -def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, - interpret, debug: bool, +def _pallas_call_impl(*args, kernels: Sequence[SpecializedKernel], + name: str, out_shapes, which_linear: Sequence[bool], + interpret: bool, debug: bool, in_shapes, - input_output_aliases: Tuple[Tuple[int, int], ...], - grid_spec: GridSpec, - **compiler_params: Any): + input_output_aliases: Tuple[Tuple[int, int], ...]): if interpret: # If we're in interpreter mode, we *scan* over the grid and eval the # discharged jaxpr. This should reproduce exactly what compiling to Triton # will do. 
+ kernel = kernels[0] + grid_spec, jaxpr = kernel.grid_spec, kernel.jaxpr grid = grid_spec.grid discharged_jaxpr, consts = state.discharge_state(jaxpr, ()) if debug: @@ -131,17 +135,23 @@ def body(carry): out_shapes=out_shapes, which_linear=which_linear, grid_spec=grid_spec, interpret=interpret, debug=debug, - input_output_aliases=input_output_aliases, - **compiler_params) + input_output_aliases=input_output_aliases) pallas_call_p.def_impl(_pallas_call_impl) def _pallas_call_abstract_eval(*avals, out_shapes, **_): return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) -def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, +def _pallas_call_jvp_rule( + primals, tangents, *, + kernels: Sequence[SpecializedKernel], + name: str, + which_linear: Sequence[bool], input_output_aliases: Tuple[Tuple[int, int], ...], - in_shapes, out_shapes, grid_spec, debug, interpret, **compiler_params: Any): + in_shapes: Sequence[jax.ShapeDtypeStruct], + out_shapes: Sequence[jax.ShapeDtypeStruct], + debug: bool, + interpret: bool): if input_output_aliases: raise NotImplementedError("JVP with aliasing not supported.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] @@ -149,44 +159,48 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, for t, inst in zip(tangents, nonzero_tangents)] tangents = [t for t in tangents if type(t) is not ad_util.Zero] nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes) - closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) - jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) - jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts - jvp_which_linear = which_linear + (True,) * len(tangents) - jvp_inshapes = (*in_shapes, *in_shapes) - jvp_outshapes = (*out_shapes, *out_shapes) - if input_output_aliases: - raise NotImplementedError("`input_output_aliases` jvp not supported.") - # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*. - # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input - # `Ref`s that are read from followed by output `Ref`s that are written to. - # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new - # jaxpr that has tangents following primals. In order for this jaxpr to be - # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around - # the jaxpr's invars. 
- logical_primals, logical_tangents = split_list( - jvp_jaxpr.invars, [len(primals) + len(out_shapes)]) - logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)]) - logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)]) - in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)]) - new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms)) - new_grid_spec = grid_spec.replace(block_mappings=new_bms) - jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs, - *logical_tangent_inputs, - *logical_primal_outputs, - *logical_tangent_outputs]) - if debug: - print(jvp_jaxpr) - out_flat = pallas_call_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr, + new_kernels = [] + for kernel in kernels: + jaxpr, grid_spec = kernel.jaxpr, kernel.grid_spec + closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) + jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) + jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts + jvp_which_linear = which_linear + (True,) * len(tangents) + jvp_inshapes = (*in_shapes, *in_shapes) + jvp_outshapes = (*out_shapes, *out_shapes) + if input_output_aliases: + raise NotImplementedError("`input_output_aliases` jvp not supported.") + # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*. + # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input + # `Ref`s that are read from followed by output `Ref`s that are written to. + # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new + # jaxpr that has tangents following primals. In order for this jaxpr to be + # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around + # the jaxpr's invars. + logical_primals, logical_tangents = split_list( + jvp_jaxpr.invars, [len(primals) + len(out_shapes)]) + logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)]) + logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)]) + in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)]) + new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms)) + new_grid_spec = grid_spec.replace(block_mappings=new_bms) + jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs, + *logical_tangent_inputs, + *logical_primal_outputs, + *logical_tangent_outputs]) + if debug: + print(jvp_jaxpr) + new_kernels.append(SpecializedKernel(kernel.name, jvp_jaxpr, + kernel.num_consts, new_grid_spec, + kernel.compiler_params)) + out_flat = pallas_call_p.bind(*primals, *tangents, kernels=new_kernels, name=f"{name}_jvp", in_shapes=jvp_inshapes, out_shapes=jvp_outshapes, - grid_spec=new_grid_spec, which_linear=jvp_which_linear, interpret=interpret, debug=debug, - input_output_aliases=(), - **compiler_params) + input_output_aliases=()) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule @@ -216,69 +230,125 @@ def _block_map_function(new_idx, *args): jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)) def _pallas_call_batching_rule(args, dims, *, - jaxpr: jax_core.Jaxpr, + kernels: Sequence[SpecializedKernel], name: str, in_shapes: Tuple[jax.ShapeDtypeStruct, ...], out_shapes: Tuple[jax.ShapeDtypeStruct, ...], - grid_spec: GridSpec, input_output_aliases: Tuple[Tuple[int, int], ...], debug: bool, interpret: bool, - which_linear: Tuple[bool, ...], - **compiler_params: Any): + which_linear: Tuple[bool, 
...]): axis_size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped} - block_mappings = grid_spec.block_mappings - avals = [v.aval for v in jaxpr.invars] - # How should we pick output dimensions? This actually matters because XLA - # can't optimize our pallas kernels, and this layout impacts performance. For - # now, because `vmap` doesn't really offer a way of inferring good output - # dimensions. For now, we just use 0. - # TODO(sharadmv): explore inferring better output dimensions via a heuristic - # TODO(sharadmv): explore a long term solution to output dim inference + new_kernels = [] + for kernel in kernels: + jaxpr, grid_spec = kernel.jaxpr, kernel.grid_spec + block_mappings = grid_spec.block_mappings + avals = [v.aval for v in jaxpr.invars] + # How should we pick output dimensions? This actually matters because XLA + # can't optimize our pallas kernels, and this layout impacts performance. For + # now, because `vmap` doesn't really offer a way of inferring good output + # dimensions. For now, we just use 0. + # TODO(sharadmv): explore inferring better output dimensions via a heuristic + # TODO(sharadmv): explore a long term solution to output dim inference - # When we have input/output aliasing, since the output will be mapped, we need - # to make sure to broadcast the input across that dimension if it is not - # mapped. - dims_ = list(dims) - args_ = list(args) - for input_index, _ in input_output_aliases: - dim = dims_[input_index] - if dim is batching.not_mapped: - dims_[input_index] = 0 - args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) - args = tuple(args_) - dims = tuple(dims_) + # When we have input/output aliasing, since the output will be mapped, we need + # to make sure to broadcast the input across that dimension if it is not + # mapped. 
+ dims_ = list(dims) + args_ = list(args) + for input_index, _ in input_output_aliases: + dim = dims_[input_index] + if dim is batching.not_mapped: + dims_[input_index] = 0 + args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + args = tuple(args_) + dims = tuple(dims_) - all_dims = list(dims) + [0] * len(out_shapes) + all_dims = list(dims) + [0] * len(out_shapes) - batched_block_mappings = map(partial(_batch_block_mapping, grid_spec.grid), - avals, all_dims, block_mappings) - batched_in_shapes = tuple( - jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else - tuple_insert(x.shape, dim, axis_size), - x.dtype) - for x, dim in zip(in_shapes, dims)) - batched_out_shapes = tuple( - jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype) - for x in out_shapes) + batched_block_mappings = map(partial(_batch_block_mapping, grid_spec.grid), + avals, all_dims, block_mappings) + batched_in_shapes = tuple( + jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else + tuple_insert(x.shape, dim, axis_size), + x.dtype) + for x, dim in zip(in_shapes, dims)) + batched_out_shapes = tuple( + jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype) + for x in out_shapes) - batched_grid_spec = grid_spec.replace(grid=(axis_size, *grid_spec.grid), - block_mappings=tuple(batched_block_mappings), - mapped_dims=(0,) + tuple(a + 1 for a in + batched_grid_spec = grid_spec.replace(grid=(axis_size, *grid_spec.grid), + block_mappings=tuple(batched_block_mappings), + mapped_dims=(0,) + tuple(a + 1 for a in grid_spec.mapped_dims)) - out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}", + new_kernel = SpecializedKernel(kernel.name, jaxpr, kernel.num_consts, + batched_grid_spec, kernel.compiler_params) + new_kernels.append(new_kernel) + + out = pallas_call_p.bind(*args, kernels=new_kernels, name=f"batched_{name}", in_shapes=batched_in_shapes, out_shapes=batched_out_shapes, which_linear=which_linear, - grid_spec=batched_grid_spec, input_output_aliases=input_output_aliases, debug=debug, - interpret=interpret, - **compiler_params) + interpret=interpret) return out, (0,) * len(out) batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule +Kernel = Callable[..., Any] +MaybeSpec = Optional[Union[pallas_core.BlockSpec, + Callable[..., pallas_core.BlockSpec]]] + +def _compute_spec(config: Config, spec: MaybeSpec, + ) -> Optional[pallas_core.BlockSpec]: + if callable(spec): + spec = spec(**config.meta) + return spec + +def specialize_kernel(config: Config, + func: Callable, + grid: Optional[pallas_core.Grid], + name: Optional[str], + in_specs: Optional[list[Optional[BlockSpec]]], + out_specs: Optional[list[Optional[BlockSpec]]], + in_avals: tuple[jax_core.ShapedArray, ...], + out_avals: tuple[jax_core.ShapedArray, ...], + in_tree: tree_util.PyTreeDef, + compiler_params: dict[str, Any] + ) -> tuple[SpecializedKernel, ...]: + specialized_grid = grid + if callable(specialized_grid): + specialized_grid = specialized_grid(**config.meta) + specialized_grid = pallas_core.preprocess_grid(specialized_grid) + specialized_in_specs = map(partial(_compute_spec, config), in_specs) + specialized_out_specs = map(partial(_compute_spec, config), out_specs) + if specialized_grid == (): + in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) + for arg in in_avals] + out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) + for arg in out_avals] + else: + in_ref_avals = [ + state.shaped_array_ref( + pallas_core.compute_shape_from_block_spec(block_spec, 
aval.shape), + aval.dtype) + for block_spec, aval in zip(specialized_in_specs, in_avals)] + out_ref_avals = [ + state.shaped_array_ref( + pallas_core.compute_shape_from_block_spec(block_spec, aval.shape), + aval.dtype) + for block_spec, aval in zip(specialized_out_specs, out_avals)] + in_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid), + specialized_in_specs) + out_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid), + specialized_out_specs) + grid_spec = pallas_core.GridSpec(specialized_grid, (*in_block_mappings, *out_block_mappings), ()) + jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr( + func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta) + return SpecializedKernel("foo", jaxpr, len(consts), grid_spec, + dict(compiler_params, **config.compiler_params)), consts, out_tree + def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, grid: Optional[Grid] = None, in_specs: Optional[Sequence[Optional[BlockSpec]]] = None, @@ -286,12 +356,21 @@ def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, input_output_aliases: Dict[int, int] = {}, interpret: bool = False, name: Optional[str] = None, + autotuning_configs: Optional[list[Config]] = None, **compiler_params: Any): + singleton = False if not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) singleton = True if not isinstance(out_shape, tuple): out_shape = tuple(out_shape) + if in_specs is not None and not isinstance(in_specs, (tuple, list)): + in_specs = (in_specs,) + if out_specs is not None and not isinstance(out_specs, (tuple, list)): + out_specs = (out_specs,) + + if not name: + name = f.__name__ if hasattr(f, "__name__") else "unnamed" @jax.jit def wrapped(*args): @@ -299,20 +378,40 @@ def wrapped(*args): flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape) jaxpr_in_tree = tree_util.tree_structure((*args, *out_shape)) - if not isinstance(f, pallas_core.Kernel): - kernel = pallas_core.Kernel(f, name, grid, in_specs, out_specs) - else: - kernel = f flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype) for a in flat_out_shapes) - specialized_kernels, consts, out_tree = kernel.specialize( - flat_in_avals, flat_out_avals, jaxpr_in_tree) + kernels = [] + flat_in_specs = in_specs + flat_out_specs = out_specs + if flat_in_specs is None: + flat_in_specs = [None] * len(flat_in_avals) + if flat_out_specs is None: + flat_out_specs = [None] * len(flat_out_avals) + all_consts = [] + if autotuning_configs is None: + specialized_kernel, consts, jaxpr_out_tree = specialize_kernel( + Config({}, {}), f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, + flat_out_avals, jaxpr_in_tree, compiler_params) + kernels.append(specialized_kernel) + all_consts.extend(consts) + else: + if len(autotuning_configs) == 0: + raise ValueError("Cannot pass in empty autotuning configs") + for config in autotuning_configs: + specialized_kernel, consts, jaxpr_out_tree = specialize_kernel( + config, f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, flat_out_avals, + jaxpr_in_tree, compiler_params) + kernels.append(specialized_kernel) + all_consts.extend(consts) + if all_consts: + raise NotImplementedError("Cannot handle consts.") + del jaxpr_out_tree which_linear = (False,) * len(flat_args) out_flat = pallas_call_p.bind( - *consts, *flat_args, - kernels=specialized_kernels, + 
*all_consts, *flat_args, + kernels=kernels, name=name, which_linear=which_linear, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) @@ -320,9 +419,7 @@ def wrapped(*args): out_shapes=tuple(flat_out_shapes), debug=debug, interpret=interpret, - input_output_aliases=tuple(input_output_aliases.items()), - **compiler_params) - print(out_flat) + input_output_aliases=tuple(input_output_aliases.items())) out = tree_util.tree_unflatten(out_tree, out_flat) if singleton: return out[0] diff --git a/jax_triton/pallas/tracing_utils.py b/jax_triton/pallas/tracing_utils.py new file mode 100644 index 00000000..5415399e --- /dev/null +++ b/jax_triton/pallas/tracing_utils.py @@ -0,0 +1,43 @@ +# Copyright 2022 The jax_triton Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for tracing utilities.""" +from typing import Any, Callable, Optional + +from jax.interpreters import partial_eval as pe +from jax._src import api_util +from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src.util import weakref_lru_cache, safe_map, safe_zip, HashablePartial +from jax._src.lax.control_flow import for_loop + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +@weakref_lru_cache +def initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, primitive_name, *args, **kwargs): + wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(HashablePartial(fun, *args, **kwargs)), in_tree) + debug_info = pe.debug_info(fun, in_tree, out_tree_thunk, False, + primitive_name or "") + jaxpr, consts = initial_style_flat_jaxpr(wrapped_fun, in_avals, + debug_info=debug_info) + return jaxpr, consts, out_tree_thunk() + +def initial_style_flat_jaxpr(fun: lu.WrappedFun, in_avals, + debug_info: Optional[jax_core.JaxprDebugInfo] = None + ) -> tuple[jax_core.Jaxpr, list[Any]]: + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals, debug_info) + jaxpr = for_loop._hoist_consts_to_refs(jaxpr) + return jaxpr, consts + diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index b3973372..8dbdfd20 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -46,6 +46,7 @@ from triton.language import ir as tl_ir import triton.libtriton.triton as _triton +from jax_triton import triton_lib from jax_triton import triton_kernel_call_lib from jax_triton import utils as triton_utils from jax_triton.pallas import primitives @@ -136,7 +137,7 @@ def _process_grid_to_3d_grid(builder, grid_spec: GridSpec): assert len(out_indices) == len(grid_spec.grid) return new_grid, out_indices -def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, +def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, num_consts: int, in_shapes, grid_spec: GridSpec, name: str) -> tl_ir.module: jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) ir_context = tl_ir.context() @@ -167,11 +168,13 @@ def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, 
in_shapes, grid_spec: Gr local_program_ids) start_indices = map(partial(_eval_index_map, ctx, program_ids), grid_spec.block_mappings) - block_infos = [BlockInfo(shape, - start_idx, block_mapping.block_shape) - if block_mapping is not None else None - for shape, block_mapping, start_idx in - zip(in_shapes, grid_spec.block_mappings, start_indices)] + arg_block_infos = [BlockInfo(shape, + start_idx, block_mapping.block_shape) + if block_mapping is not None else None + for shape, block_mapping, start_idx in + zip(in_shapes, grid_spec.block_mappings, start_indices)] + const_block_infos = [None] * num_consts + block_infos = [*const_block_infos, *arg_block_infos] () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args) ctx.builder.ret_void() ctx.builder.set_insert_block(insert_pt) @@ -798,21 +801,10 @@ def _while_lowering_rule(ctx: TritonLoweringRuleContext, *args, cond_nconsts, triton_lowering_rules[lax.while_p] = _while_lowering_rule @weakref_lru_cache -def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, +def compile_jaxpr(jaxpr: jax_core.Jaxpr, num_consts: int, in_shapes, grid_spec: GridSpec, name: str, num_warps: int, num_stages: int ) -> TritonCompilationResult: - lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name) - backend = _triton.runtime.backend.CUDA - device = 0 - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, lowering_result.module, device, - num_warps, num_stages, {}, 0) - return TritonCompilationResult(name, asm, shared_mem, lowering_result) - -@weakref_lru_cache -def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, - name: str, num_warps: int, num_stages: int - ) -> TritonCompilationResult: - lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name) + lowering_result = lower_jaxpr_to_triton_module(jaxpr, num_consts, in_shapes, grid_spec, name) backend = _triton.runtime.backend.CUDA device = 0 name, asm, shared_mem = _triton.code_gen.compile_ttir( @@ -821,59 +813,96 @@ def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, - jaxpr: jax_core.Jaxpr, + kernels: Sequence[pallas_core.SpecializedKernel], name: str, in_shapes: Tuple[jax.ShapeDtypeStruct, ...], out_shapes: Tuple[jax.ShapeDtypeStruct, ...], which_linear: Tuple[bool, ...], interpret: bool, debug: bool, - input_output_aliases: Tuple[Tuple[int, int], ...], - grid_spec: GridSpec, - **compiler_params: Any): + input_output_aliases: Tuple[Tuple[int, int], ...]): if interpret: return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( - ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, + ctx, *in_nodes, kernels=kernels, name=name, out_shapes=out_shapes, in_shapes=in_shapes, which_linear=which_linear, interpret=interpret, debug=debug, - input_output_aliases=input_output_aliases, - grid_spec=grid_spec, **compiler_params) - num_warps = compiler_params.get("num_warps", 4) - num_stages = compiler_params.get("num_stages", 3) - compilation_result = compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)), - grid_spec, name, num_warps, num_stages) - name = compilation_result.name - asm = compilation_result.asm - shared_mem = compilation_result.shared_mem - if debug: - print(jaxpr) - print(grid_spec) - lowering_result = compilation_result.lowering_result - if debug: - lowering_result.module.print() + input_output_aliases=input_output_aliases) + lowered_kernels = [] + for kernel in kernels: + if debug: + print(kernel.jaxpr) + 
print(kernel.grid_spec) + compiler_params = kernel.compiler_params + num_warps = compiler_params.get("num_warps", 4) + num_stages = compiler_params.get("num_stages", 3) + compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts, + tuple((*in_shapes, *out_shapes)), + kernel.grid_spec, kernel.name, num_warps, num_stages) + name = compilation_result.name + asm = compilation_result.asm + shared_mem = compilation_result.shared_mem + lowering_result = compilation_result.lowering_result + if debug: + lowering_result.module.print() + lowered_kernels.append((name, asm, shared_mem, lowering_result)) out_type = ir.TupleType.get_tuple([ ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype)) for out_shape in ctx.avals_out]) i32_type = ir.IntegerType.get_signless(32) - kernel = triton_kernel_call_lib.TritonKernel( - asm["cubin"], name, num_warps, shared_mem - ) - - grid = triton_utils.normalize_grid( - compilation_result.lowering_result.grid, metaparams={}) - kernel_params = [] - for _ in range(len(in_shapes) + len(out_shapes)): - kernel_params.append( - triton_kernel_call_lib.create_array_parameter( - 0, # bytes to zero # TODO(cjfj): Expose through user API. - True, # divisible by 16 + if len(lowered_kernels) == 1: + name, asm, shared_mem, lowering_result = lowered_kernels[0] + kernel = triton_kernel_call_lib.TritonKernel( + asm["cubin"], name, num_warps, shared_mem + ) + grid = triton_utils.normalize_grid( + lowering_result.grid, metaparams={}) + # All arguments are buffers. + kernel_params = [] + for _ in range(len(in_shapes) + len(out_shapes)): + kernel_params.append( + triton_kernel_call_lib.create_array_parameter( + 0, # bytes to zero # TODO(cjfj): Expose through user API. + True, # divisible by 16 + ) + ) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, grid[0], grid[1], grid[2], kernel_params + ) + elif len(lowered_kernels) > 1: + kernel_calls = [] + for name, asm, shared_mem, lowering_result in lowered_kernels: + kernel = triton_kernel_call_lib.TritonKernel( + asm["cubin"], name, num_warps, shared_mem + ) + grid = triton_utils.normalize_grid( + lowering_result.grid, metaparams={}) + # All arguments are buffers. + kernel_params = [] + for _ in range(len(in_shapes) + len(out_shapes)): + kernel_params.append( + triton_kernel_call_lib.create_array_parameter( + 0, # bytes to zero # TODO(cjfj): Expose through user API. 
+ True, # divisible by 16 + ) ) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, grid[0], grid[1], grid[2], kernel_params + ) + kernel_calls.append(kernel_call) + input_output_aliases_with_sizes = tuple( + (input_idx, output_idx, triton_lib.aval_size_bytes(ctx.avals_in[input_idx])) + for input_idx, output_idx in input_output_aliases ) - kernel_call = triton_kernel_call_lib.TritonKernelCall( - kernel, grid[0], grid[1], grid[2], kernel_params - ) + + kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall( + name, + [(call, f"config{i}") for i, call in enumerate(kernel_calls)], + input_output_aliases_with_sizes, + ) + else: + raise ValueError("Cannot have 0 kernels.") ctx.module_context.add_keepalive(kernel_call) output_operand_aliases = ir.ArrayAttr.get([ diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 9b4e32fa..076c1677 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -33,9 +33,9 @@ import jax.numpy as jnp import jax_triton as jt from jax_triton import pallas as pl -from jax_triton.pallas.core import _initial_style_open_jaxpr from jax_triton.pallas.ops import attention from jax_triton.pallas.ops import layer_norm +from jax_triton.pallas.tracing_utils import initial_style_open_jaxpr try: from jax_triton.pallas.triton_ir_lowering import compile_jaxpr except ModuleNotFoundError: @@ -123,7 +123,7 @@ def setUp(self): super().setUp() if compile_jaxpr: compile_jaxpr.cache_clear() - _initial_style_open_jaxpr.cache_clear() + initial_style_open_jaxpr.cache_clear() def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) @@ -851,6 +851,49 @@ def body(x_ref): lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) +class PallasAutotuningTest(PallasTest): + + def test_basic_autotuning(self): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + grid=lambda block_size: 8 // block_size, + autotuning_configs=[ + pl.Config(dict(block_size=2), {}), + pl.Config(dict(block_size=4), {}), + ]) + def add_one(x_ref, o_ref, *, block_size): + idx = pl.program_id(0) * block_size + jnp.arange(block_size) + o_ref[idx] = x_ref[idx] + 1. + + x = jnp.arange(8.) + np.testing.assert_allclose(add_one(x), x + 1.) + + def test_basic_autotuning_with_block_spec(self): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + grid=lambda block_size: 8 // block_size, + in_specs=[ + lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)), + ], + out_specs=[ + lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)), + ], + autotuning_configs=[ + pl.Config(dict(block_size=1), {}), + pl.Config(dict(block_size=2), {}), + pl.Config(dict(block_size=4), {}), + pl.Config(dict(block_size=8), {}), + ], + debug=True) + def add_one(x_ref, o_ref, *, block_size): + del block_size + o_ref[...] = x_ref[...] + 1. + + x = jnp.arange(8.) + np.testing.assert_allclose(add_one(x), x + 1.) 
+ class FusedAttentionTest(parameterized.TestCase): @parameterized.parameters(*[ From 295df3d1b551414f80612c19f6f51d6e3c9b6b88 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 30 Mar 2023 20:41:36 +0000 Subject: [PATCH 3/6] Remove lambdas --- jax_triton/pallas/__init__.py | 2 +- jax_triton/pallas/core.py | 14 ++- jax_triton/pallas/pallas_call.py | 113 +++++++++++++++--------- jax_triton/pallas/triton_ir_lowering.py | 8 +- tests/pallas_test.py | 26 +++--- 5 files changed, 102 insertions(+), 61 deletions(-) diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py index 98a1c415..df581891 100644 --- a/jax_triton/pallas/__init__.py +++ b/jax_triton/pallas/__init__.py @@ -14,7 +14,7 @@ """Module for pallas, a jaxpr "dialect" for Triton.""" from jax_triton.pallas.core import BlockSpec -from jax_triton.pallas.core import Config +from jax_triton.pallas.core import KernelConfig from jax_triton.pallas.pallas_call import pallas_call from jax_triton.pallas.pallas_call import pallas_call_p from jax_triton.pallas.primitives import atomic_add diff --git a/jax_triton/pallas/core.py b/jax_triton/pallas/core.py index c254b356..7a1405b5 100644 --- a/jax_triton/pallas/core.py +++ b/jax_triton/pallas/core.py @@ -18,7 +18,7 @@ import functools from functools import partial -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union import jax.numpy as jnp from jax._src import api_util @@ -95,6 +95,18 @@ class GridSpec: Platform = str + +@dataclasses.dataclass +class KernelConfig: + in_specs: Optional[Sequence[Optional[BlockSpec]]] = None + out_specs: Optional[Sequence[Optional[BlockSpec]]] = None + grid: Optional[Union[Grid, int]] = None + meta: dict[str, Any] = dataclasses.field(default_factory=dict) + compiler_params: dict[Platform, dict[str, Any]] = dataclasses.field(default_factory=dict) + + def replace(self, *args, **kwargs): + return dataclasses.replace(self, *args, **kwargs) + @dataclasses.dataclass class Config: meta: dict[str, Any] diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index f782c0f9..6ca68970 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -306,24 +306,16 @@ def _compute_spec(config: Config, spec: MaybeSpec, spec = spec(**config.meta) return spec -def specialize_kernel(config: Config, +def specialize_kernel(config: pallas_core.KernelConfig, func: Callable, - grid: Optional[pallas_core.Grid], name: Optional[str], - in_specs: Optional[list[Optional[BlockSpec]]], - out_specs: Optional[list[Optional[BlockSpec]]], in_avals: tuple[jax_core.ShapedArray, ...], out_avals: tuple[jax_core.ShapedArray, ...], in_tree: tree_util.PyTreeDef, compiler_params: dict[str, Any] ) -> tuple[SpecializedKernel, ...]: - specialized_grid = grid - if callable(specialized_grid): - specialized_grid = specialized_grid(**config.meta) - specialized_grid = pallas_core.preprocess_grid(specialized_grid) - specialized_in_specs = map(partial(_compute_spec, config), in_specs) - specialized_out_specs = map(partial(_compute_spec, config), out_specs) - if specialized_grid == (): + grid = config.grid + if grid == (): in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) for arg in in_avals] out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) @@ -333,42 +325,76 @@ def specialize_kernel(config: Config, state.shaped_array_ref( pallas_core.compute_shape_from_block_spec(block_spec, aval.shape), aval.dtype) - for block_spec, aval 
in zip(specialized_in_specs, in_avals)] + for block_spec, aval in zip(config.in_specs, in_avals)] out_ref_avals = [ state.shaped_array_ref( pallas_core.compute_shape_from_block_spec(block_spec, aval.shape), aval.dtype) - for block_spec, aval in zip(specialized_out_specs, out_avals)] - in_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid), - specialized_in_specs) - out_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid), - specialized_out_specs) - grid_spec = pallas_core.GridSpec(specialized_grid, (*in_block_mappings, *out_block_mappings), ()) + for block_spec, aval in zip(config.out_specs, out_avals)] + in_block_mappings = map( + partial(pallas_core.convert_block_spec_to_block_mapping, grid), + config.in_specs) + out_block_mappings = map( + partial(pallas_core.convert_block_spec_to_block_mapping, grid), + config.out_specs) + grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ()) jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr( func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta) return SpecializedKernel("foo", jaxpr, len(consts), grid_spec, dict(compiler_params, **config.compiler_params)), consts, out_tree -def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, +def _canonicalize_kernel_config( + maybe_kernel_config: Optional[pallas_core.KernelConfig], + in_avals: Sequence[jax_core.AbstractValue], + out_avals: Sequence[jax_core.AbstractValue], + in_specs: Optional[Sequence[Optional[BlockSpec]]], + out_specs: Optional[Sequence[Optional[BlockSpec]]], + grid: Optional[Union[Grid, int]], + ) -> pallas_core.KernelConfig: + if not maybe_kernel_config: + config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, grid=grid) + else: + config = maybe_kernel_config + grid = maybe_kernel_config.grid + grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs + grid = pallas_core.preprocess_grid(grid) + if in_specs is not None and not isinstance(in_specs, (tuple, list)): + in_specs = (in_specs,) + if out_specs is not None and not isinstance(out_specs, (tuple, list)): + out_specs = (out_specs,) + if in_specs is None: + in_specs = [None] * len(in_avals) + if out_specs is None: + out_specs = [None] * len(out_avals) + return config.replace(grid=grid, in_specs=in_specs, out_specs=out_specs) + +def pallas_call(f: Callable, out_shape: Any, *, grid: Optional[Grid] = None, + config: Optional[pallas_core.KernelConfig] = None, in_specs: Optional[Sequence[Optional[BlockSpec]]] = None, out_specs: Optional[Sequence[Optional[BlockSpec]]] = None, input_output_aliases: Dict[int, int] = {}, interpret: bool = False, name: Optional[str] = None, - autotuning_configs: Optional[list[Config]] = None, + autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None, + debug: bool = False, **compiler_params: Any): + if config is not None: + if grid is not None or in_specs is not None or out_specs is not None: + raise ValueError("Cannot specify both `config` and any of `grid`, " + "`in_specs`, or `out_specs`.") + if autotuning_configs is not None: + raise ValueError("Cannot specify both `config` and `autotuning_configs`") + if autotuning_configs is not None: + if grid is not None or in_specs is not None or out_specs is not None: + raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, " + "`in_specs`, or `out_specs`.") singleton = False if not isinstance(out_shape, (tuple, list)): 
out_shape = (out_shape,) singleton = True if not isinstance(out_shape, tuple): out_shape = tuple(out_shape) - if in_specs is not None and not isinstance(in_specs, (tuple, list)): - in_specs = (in_specs,) - if out_specs is not None and not isinstance(out_specs, (tuple, list)): - out_specs = (out_specs,) - if not name: name = f.__name__ if hasattr(f, "__name__") else "unnamed" @@ -382,29 +408,32 @@ def wrapped(*args): for a in flat_args) flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype) for a in flat_out_shapes) + canonicalized_configs = [] + if autotuning_configs is None: + canonicalized_configs.append(_canonicalize_kernel_config(config, + flat_in_avals, + flat_out_avals, + in_specs, + out_specs, + grid)) + else: + canonicalized_configs.extend(map(partial(_canonicalize_kernel_config, + in_avals=flat_in_avals, + out_avals=flat_out_avals, + in_specs=in_specs, + out_specs=out_specs, + grid=grid), + autotuning_configs)) kernels = [] - flat_in_specs = in_specs - flat_out_specs = out_specs - if flat_in_specs is None: - flat_in_specs = [None] * len(flat_in_avals) - if flat_out_specs is None: - flat_out_specs = [None] * len(flat_out_avals) all_consts = [] - if autotuning_configs is None: + if len(canonicalized_configs) == 0: + raise ValueError("Cannot pass in empty autotuning configs") + for canonicalized_config in canonicalized_configs: specialized_kernel, consts, jaxpr_out_tree = specialize_kernel( - Config({}, {}), f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, + canonicalized_config, f, name, flat_in_avals, flat_out_avals, jaxpr_in_tree, compiler_params) kernels.append(specialized_kernel) all_consts.extend(consts) - else: - if len(autotuning_configs) == 0: - raise ValueError("Cannot pass in empty autotuning configs") - for config in autotuning_configs: - specialized_kernel, consts, jaxpr_out_tree = specialize_kernel( - config, f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, flat_out_avals, - jaxpr_in_tree, compiler_params) - kernels.append(specialized_kernel) - all_consts.extend(consts) if all_consts: raise NotImplementedError("Cannot handle consts.") del jaxpr_out_tree diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index 8dbdfd20..bbca5b54 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -833,9 +833,11 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, if debug: print(kernel.jaxpr) print(kernel.grid_spec) - compiler_params = kernel.compiler_params - num_warps = compiler_params.get("num_warps", 4) - num_stages = compiler_params.get("num_stages", 3) + compiler_params = dict(kernel.compiler_params) + num_warps = compiler_params.pop("num_warps", 4) + num_stages = compiler_params.pop("num_stages", 3) + if compiler_params: + raise ValueError(f"Invalid compiler params: {compiler_params}") compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts, tuple((*in_shapes, *out_shapes)), kernel.grid_spec, kernel.name, num_warps, num_stages) diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 076c1677..92051f27 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -857,10 +857,10 @@ def test_basic_autotuning(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), - grid=lambda block_size: 8 // block_size, autotuning_configs=[ - pl.Config(dict(block_size=2), {}), - pl.Config(dict(block_size=4), {}), + pl.KernelConfig(meta=dict(block_size=block_size), + grid=8 // block_size) + for 
block_size in [1, 2, 4, 8] ]) def add_one(x_ref, o_ref, *, block_size): idx = pl.program_id(0) * block_size + jnp.arange(block_size) @@ -873,18 +873,16 @@ def test_basic_autotuning_with_block_spec(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), - grid=lambda block_size: 8 // block_size, - in_specs=[ - lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)), - ], - out_specs=[ - lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)), - ], autotuning_configs=[ - pl.Config(dict(block_size=1), {}), - pl.Config(dict(block_size=2), {}), - pl.Config(dict(block_size=4), {}), - pl.Config(dict(block_size=8), {}), + pl.KernelConfig(meta=dict(block_size=block_size), + in_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)) + ], + out_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)) + ], + grid=8 // block_size) + for block_size in [1, 2, 4, 8] ], debug=True) def add_one(x_ref, o_ref, *, block_size): From 5ce77274b3c9689ef869b39a6bb1ddbf7d4c4941 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 30 Mar 2023 20:49:07 +0000 Subject: [PATCH 4/6] Rebase --- jax_triton/pallas/__init__.py | 7 ++----- jax_triton/pallas/triton_ir_lowering.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py index df581891..6ddbc645 100644 --- a/jax_triton/pallas/__init__.py +++ b/jax_triton/pallas/__init__.py @@ -36,8 +36,5 @@ from jax_triton.pallas.primitives import swap from jax_triton.utils import cdiv -try: - from jax_triton.pallas import triton_ir_lowering - del triton_ir_lowering -except (ImportError, ModuleNotFoundError): - pass +from jax_triton.pallas import triton_ir_lowering +del triton_ir_lowering diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index bbca5b54..fbd94db8 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -44,7 +44,7 @@ import triton import triton.language as tl from triton.language import ir as tl_ir -import triton.libtriton.triton as _triton +import triton._C.libtriton.triton as _triton from jax_triton import triton_lib from jax_triton import triton_kernel_call_lib From 4b73bde897001c220d21267b0c266a78c25bb1d5 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 30 Mar 2023 21:49:49 +0000 Subject: [PATCH 5/6] Treat names properly --- jax_triton/pallas/__init__.py | 7 +++-- jax_triton/pallas/core.py | 13 +--------- jax_triton/pallas/pallas_call.py | 34 +++++++++++++++---------- jax_triton/pallas/triton_ir_lowering.py | 10 +++++++- tests/pallas_test.py | 10 +++++--- 5 files changed, 42 insertions(+), 32 deletions(-) diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py index 6ddbc645..df581891 100644 --- a/jax_triton/pallas/__init__.py +++ b/jax_triton/pallas/__init__.py @@ -36,5 +36,8 @@ from jax_triton.pallas.primitives import swap from jax_triton.utils import cdiv -from jax_triton.pallas import triton_ir_lowering -del triton_ir_lowering +try: + from jax_triton.pallas import triton_ir_lowering + del triton_ir_lowering +except (ImportError, ModuleNotFoundError): + pass diff --git a/jax_triton/pallas/core.py b/jax_triton/pallas/core.py index 7a1405b5..465b3d10 100644 --- a/jax_triton/pallas/core.py +++ b/jax_triton/pallas/core.py @@ -95,9 +95,9 @@ class GridSpec: Platform = str - @dataclasses.dataclass class KernelConfig: + name: Optional[str] = None in_specs: Optional[Sequence[Optional[BlockSpec]]] = None out_specs: 
Optional[Sequence[Optional[BlockSpec]]] = None grid: Optional[Union[Grid, int]] = None @@ -107,17 +107,6 @@ class KernelConfig: def replace(self, *args, **kwargs): return dataclasses.replace(self, *args, **kwargs) -@dataclasses.dataclass -class Config: - meta: dict[str, Any] - compiler_params: dict[Platform, dict[str, Any]] - - def to_string(self, platform: str) -> str: - compiler_params = self.compiler_params.get(platform, {}) - return "-".join([*(f"{k}_{v}" for k, v in self.meta.items()), - *(f"{k}_{v}" for k, v in compiler_params.items())]) - - def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid: if grid is None: return () diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index 6ca68970..23396b48 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -47,7 +47,6 @@ Grid = Tuple[int, ...] BlockSpec = pallas_core.BlockSpec BlockMapping = pallas_core.BlockMapping -Config = pallas_core.Config GridSpec = pallas_core.GridSpec SpecializedKernel = pallas_core.SpecializedKernel @@ -299,20 +298,16 @@ def _pallas_call_batching_rule(args, dims, *, Kernel = Callable[..., Any] MaybeSpec = Optional[Union[pallas_core.BlockSpec, Callable[..., pallas_core.BlockSpec]]] - -def _compute_spec(config: Config, spec: MaybeSpec, - ) -> Optional[pallas_core.BlockSpec]: - if callable(spec): - spec = spec(**config.meta) - return spec +Platform = str +CompilerParams = dict[Platform, dict[str, Any]] def specialize_kernel(config: pallas_core.KernelConfig, func: Callable, - name: Optional[str], + name: str, in_avals: tuple[jax_core.ShapedArray, ...], out_avals: tuple[jax_core.ShapedArray, ...], in_tree: tree_util.PyTreeDef, - compiler_params: dict[str, Any] + compiler_params: CompilerParams, ) -> tuple[SpecializedKernel, ...]: grid = config.grid if grid == (): @@ -340,7 +335,9 @@ def specialize_kernel(config: pallas_core.KernelConfig, grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ()) jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr( func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta) - return SpecializedKernel("foo", jaxpr, len(consts), grid_spec, + if config.name is not None: + name = f"{name}_{config.name}" + return SpecializedKernel(name, jaxpr, len(consts), grid_spec, dict(compiler_params, **config.compiler_params)), consts, out_tree def _canonicalize_kernel_config( @@ -350,11 +347,16 @@ def _canonicalize_kernel_config( in_specs: Optional[Sequence[Optional[BlockSpec]]], out_specs: Optional[Sequence[Optional[BlockSpec]]], grid: Optional[Union[Grid, int]], + compiler_params: dict[str, Any], ) -> pallas_core.KernelConfig: if not maybe_kernel_config: - config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, grid=grid) + config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, + grid=grid, + compiler_params=compiler_params) else: config = maybe_kernel_config + config = config.replace(compiler_params=dict(compiler_params, + **config.compiler_params)) grid = maybe_kernel_config.grid grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs grid = pallas_core.preprocess_grid(grid) @@ -378,7 +380,7 @@ def pallas_call(f: Callable, out_shape: Any, *, name: Optional[str] = None, autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None, debug: bool = False, - **compiler_params: Any): + compiler_params: Optional[CompilerParams] = None): if config is not None: if grid is not None or in_specs is not None or 
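
A small sketch of how the call-level `compiler_params` interact with the per-config ones after this change: `_canonicalize_kernel_config` combines them with `dict(compiler_params, **config.compiler_params)`, which is a shallow merge, so a config that supplies its own `triton` entry replaces the call-level `triton` dict wholesale, and any knobs it omits fall back to the defaults the lowering pops (`num_warps=4`, `num_stages=3`). The values below are illustrative, not taken from the library.

    # Shallow merge of compiler params (illustrative values):
    call_level = dict(triton=dict(num_warps=4, num_stages=3))  # passed to pallas_call
    per_config = dict(triton=dict(num_warps=8))                # set on one KernelConfig
    merged = dict(call_level, **per_config)
    # The per-config "triton" entry replaces the call-level one entirely;
    # num_stages is later defaulted to 3 by the Triton lowering.
    assert merged == {"triton": {"num_warps": 8}}
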
out_specs is not None: raise ValueError("Cannot specify both `config` and any of `grid`, " @@ -389,6 +391,8 @@ def pallas_call(f: Callable, out_shape: Any, *, if grid is not None or in_specs is not None or out_specs is not None: raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, " "`in_specs`, or `out_specs`.") + if compiler_params is None: + compiler_params = {} singleton = False if not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) @@ -415,14 +419,16 @@ def wrapped(*args): flat_out_avals, in_specs, out_specs, - grid)) + grid, + compiler_params)) else: canonicalized_configs.extend(map(partial(_canonicalize_kernel_config, in_avals=flat_in_avals, out_avals=flat_out_avals, in_specs=in_specs, out_specs=out_specs, - grid=grid), + grid=grid, + compiler_params=compiler_params), autotuning_configs)) kernels = [] all_consts = [] diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index fbd94db8..8f1c2ce9 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -800,10 +800,17 @@ def _while_lowering_rule(ctx: TritonLoweringRuleContext, *args, cond_nconsts, return post_args triton_lowering_rules[lax.while_p] = _while_lowering_rule +def _mangle_name(name: str): + name = name.replace("-", "_") + name = name.replace("-", "_") + name = name.replace("=", "_") + return name + @weakref_lru_cache def compile_jaxpr(jaxpr: jax_core.Jaxpr, num_consts: int, in_shapes, grid_spec: GridSpec, name: str, num_warps: int, num_stages: int ) -> TritonCompilationResult: + name = _mangle_name(name) lowering_result = lower_jaxpr_to_triton_module(jaxpr, num_consts, in_shapes, grid_spec, name) backend = _triton.runtime.backend.CUDA device = 0 @@ -833,7 +840,8 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, if debug: print(kernel.jaxpr) print(kernel.grid_spec) - compiler_params = dict(kernel.compiler_params) + print(kernel.compiler_params) + compiler_params = dict(kernel.compiler_params).get("triton", {}) num_warps = compiler_params.pop("num_warps", 4) num_stages = compiler_params.pop("num_stages", 3) if compiler_params: diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 92051f27..153cb939 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -874,15 +874,19 @@ def test_basic_autotuning_with_block_spec(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), autotuning_configs=[ - pl.KernelConfig(meta=dict(block_size=block_size), + pl.KernelConfig(name=f"block_size={block_size}_num_warps={num_warps}_num_stages={num_stages}", + meta=dict(block_size=block_size), in_specs=[ pl.BlockSpec(lambda i: i, (block_size,)) ], out_specs=[ pl.BlockSpec(lambda i: i, (block_size,)) ], - grid=8 // block_size) - for block_size in [1, 2, 4, 8] + grid=8 // block_size, + compiler_params=dict(triton=dict(num_warps=num_warps, + num_stages=num_stages))) + for block_size in [4, 8] + for num_warps, num_stages in zip([4, 4], [3, 2]) ], debug=True) def add_one(x_ref, o_ref, *, block_size): From 9fee747d0ab025a6b060d34e29ad2b5749fcba43 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 30 Mar 2023 23:07:23 +0000 Subject: [PATCH 6/6] Add matmul example --- jax_triton/pallas/ops/matmul.py | 80 +++++++++++++++++++++++++ jax_triton/pallas/tracing_utils.py | 1 + jax_triton/pallas/triton_ir_lowering.py | 9 +-- tests/pallas_test.py | 45 ++++++++++++++ 4 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 
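
Before the matmul example below, a usage sketch of the autotuning surface as it stands after the preceding patch: each candidate is a named `pl.KernelConfig` whose `meta` entries are forwarded to the kernel as keyword arguments and whose `compiler_params` are keyed by platform. The kernel, shapes, and config values here are illustrative, not part of the patch.

    import functools
    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
        autotuning_configs=[
            pl.KernelConfig(
                name=f"block_size={block_size}",    # mangled into the Triton kernel name
                meta=dict(block_size=block_size),   # becomes the `block_size` kwarg below
                in_specs=[pl.BlockSpec(lambda i: i, (block_size,))],
                out_specs=[pl.BlockSpec(lambda i: i, (block_size,))],
                grid=8 // block_size,
                compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
            )
            for block_size in (4, 8)
        ])
    def double(x_ref, o_ref, *, block_size):
      del block_size  # only used to pick the grid and block shapes
      o_ref[...] = x_ref[...] * 2.

    x = jnp.arange(8., dtype=jnp.float32)
    # Each candidate is compiled; the autotuning machinery selects among them.
    print(double(x))
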
jax_triton/pallas/ops/matmul.py diff --git a/jax_triton/pallas/ops/matmul.py b/jax_triton/pallas/ops/matmul.py new file mode 100644 index 00000000..a9c7b89d --- /dev/null +++ b/jax_triton/pallas/ops/matmul.py @@ -0,0 +1,80 @@ +# Copyright 2023 The jax_triton Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module containing fused matmuls.""" +import functools + +from typing import Any, Optional + +import jax +import jax.numpy as jnp +from jax._src.lax.control_flow.for_loop import for_loop + +import jax_triton as jt +from jax_triton import pallas as pl + +def _compute_bound_configs(): + yield from [ + dict(bm=128, bn=256, bk=32, compiler_params=dict(triton=dict(num_stages=3, num_warps=8))), + dict(bm=256, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=3, num_warps=8))), + dict(bm=256, bn=64 , bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=64, bn=256, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=128, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=128, bn=64, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=64, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=128, bn=32, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=64, bn=32, bk=32, compiler_params=dict(triton=dict(num_stages=5, num_warps=2))), + ] + +@functools.partial(jax.jit, static_argnames=["interpret", "debug"]) +def matmul(x, y, *, interpret=False, debug=False): + # TODO(sharadmv): make this implementation better + # 1. reordered programs for better L2 + # 2. split K + # 3. 
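
A pure-NumPy reference of the tiling scheme the kernel below implements may help: grid program (i, j) owns one (bm, bn) output tile and accumulates partial products over the k dimension in bk-sized chunks, which is what the BlockSpecs and the inner for_loop express. This is an illustrative model only, not the Triton kernel.

    import numpy as np

    def blocked_matmul_reference(x, y, bm, bn, bk):
      """NumPy model of the (bm, bn, bk) tiling used by the Pallas kernel."""
      m, k = x.shape
      _, n = y.shape
      out = np.zeros((m, n), dtype=np.float32)
      for i in range(m // bm):            # grid axis 0
        for j in range(n // bn):          # grid axis 1
          acc = np.zeros((bm, bn), dtype=np.float32)
          for s in range(k // bk):        # the kernel's for_loop over k
            acc += x[i*bm:(i+1)*bm, s*bk:(s+1)*bk] @ y[s*bk:(s+1)*bk, j*bn:(j+1)*bn]
          out[i*bm:(i+1)*bm, j*bn:(j+1)*bn] = acc
      return out

    x = np.random.randn(128, 64).astype(np.float32)
    y = np.random.randn(64, 128).astype(np.float32)
    np.testing.assert_allclose(blocked_matmul_reference(x, y, 32, 32, 16), x @ y,
                               rtol=1e-4, atol=1e-4)
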
masking + m, n, k = x.shape[0], y.shape[1], x.shape[1] + + @functools.partial( + pl.pallas_call, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + interpret=interpret, + debug=debug, + autotuning_configs=[ + pl.KernelConfig( + name=f"bm={config['bm']}_bn={config['bn']}_bk={config['bk']}", + meta=dict(bk=config["bk"]), + in_specs=[ + pl.BlockSpec(lambda i, _: (i, 0), (config["bm"], x.shape[1])), + pl.BlockSpec(lambda _, j: (0, j), (y.shape[0], config["bn"])) + ], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (config["bm"], + config["bn"])), + grid=(jt.cdiv(m, config["bm"]), jt.cdiv(n, config["bn"])), + compiler_params=config["compiler_params"], + ) + for config in _compute_bound_configs() + ] + ) + def matmul_kernel(x_ref, y_ref, o_ref, *, bk: int): + acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) + def body(i, acc_ref): + x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) + y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) + acc_ref[:, :] += jnp.dot(x_block, y_block) + acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) + o_ref[:, :] = acc + return matmul_kernel(x, y) + +@jax.jit +def matmul_reference(x, y): + return jnp.dot(x, y) diff --git a/jax_triton/pallas/tracing_utils.py b/jax_triton/pallas/tracing_utils.py index 5415399e..4ec68658 100644 --- a/jax_triton/pallas/tracing_utils.py +++ b/jax_triton/pallas/tracing_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for tracing utilities.""" + from typing import Any, Callable, Optional from jax.interpreters import partial_eval as pe diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index 8f1c2ce9..96f2e1a2 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -848,21 +848,22 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, raise ValueError(f"Invalid compiler params: {compiler_params}") compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts, tuple((*in_shapes, *out_shapes)), - kernel.grid_spec, kernel.name, num_warps, num_stages) + kernel.grid_spec, kernel.name, num_warps, + num_stages) name = compilation_result.name asm = compilation_result.asm shared_mem = compilation_result.shared_mem lowering_result = compilation_result.lowering_result if debug: lowering_result.module.print() - lowered_kernels.append((name, asm, shared_mem, lowering_result)) + lowered_kernels.append((name, asm, shared_mem, lowering_result, num_warps)) out_type = ir.TupleType.get_tuple([ ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype)) for out_shape in ctx.avals_out]) i32_type = ir.IntegerType.get_signless(32) if len(lowered_kernels) == 1: - name, asm, shared_mem, lowering_result = lowered_kernels[0] + name, asm, shared_mem, lowering_result, num_warps = lowered_kernels[0] kernel = triton_kernel_call_lib.TritonKernel( asm["cubin"], name, num_warps, shared_mem ) @@ -882,7 +883,7 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, ) elif len(lowered_kernels) > 1: kernel_calls = [] - for name, asm, shared_mem, lowering_result in lowered_kernels: + for name, asm, shared_mem, lowering_result, num_warps in lowered_kernels: kernel = triton_kernel_call_lib.TritonKernel( asm["cubin"], name, num_warps, shared_mem ) diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 153cb939..dd7cd2e1 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -35,6 +35,7 @@ from jax_triton import 
pallas as pl from jax_triton.pallas.ops import attention from jax_triton.pallas.ops import layer_norm +from jax_triton.pallas.ops import matmul as matmul_ops from jax_triton.pallas.tracing_utils import initial_style_open_jaxpr try: from jax_triton.pallas.triton_ir_lowering import compile_jaxpr @@ -896,6 +897,50 @@ def add_one(x_ref, o_ref, *, block_size): x = jnp.arange(8.) np.testing.assert_allclose(add_one(x), x + 1.) + @parameterized.parameters( + (256, 256, 256), + (1024, 1024, 1024), + (1024, 512, 1024), + (1024, 512, 512), + ) + def test_matmul_autotuned(self, m, n, k): + key = random.PRNGKey(0) + k1, k2 = random.split(key) + x = random.normal(k1, (m, k), dtype=jnp.float16) + y = random.normal(k2, (k, n), dtype=jnp.float16) + np.testing.assert_allclose(matmul_ops.matmul(x, y, debug=True), + matmul_ops.matmul_reference(x, y), + atol=0.05, rtol=0.05) + + + def test_vmap_of_autotuned_kernel(self): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + autotuning_configs=[ + pl.KernelConfig(name=f"block_size={block_size}_num_warps={num_warps}_num_stages={num_stages}", + meta=dict(block_size=block_size), + in_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)), + pl.BlockSpec(lambda i: i, (block_size,)) + ], + out_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)) + ], + grid=8 // block_size, + compiler_params=dict(triton=dict(num_warps=num_warps, + num_stages=num_stages))) + for block_size in [4, 8] + for num_warps, num_stages in zip([4, 4], [3, 2]) + ], + debug=True) + def add(x_ref, y_ref, o_ref, *, block_size): + del block_size + o_ref[...] = x_ref[...] + y_ref[...] + + x = jnp.arange(16.).reshape((2, 8)) + np.testing.assert_allclose(jax.vmap(add)(x, x), x + x) + class FusedAttentionTest(parameterized.TestCase): @parameterized.parameters(*[
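
A quick usage sketch of the autotuned matmul op exercised by the new tests; the shapes are illustrative and a CUDA-capable runtime is assumed (the op also exposes `interpret` and `debug` flags).

    import jax.numpy as jnp
    from jax import random
    from jax_triton.pallas.ops import matmul as matmul_ops

    k1, k2 = random.split(random.PRNGKey(0))
    x = random.normal(k1, (1024, 512), dtype=jnp.float16)
    y = random.normal(k2, (512, 1024), dtype=jnp.float16)

    # Every config from _compute_bound_configs() is compiled as a candidate
    # kernel; the autotuning machinery in this series selects among them.
    out = matmul_ops.matmul(x, y)
    print(jnp.max(jnp.abs(out - matmul_ops.matmul_reference(x, y))))
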