diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py index 6ef4b147..df581891 100644 --- a/jax_triton/pallas/__init__.py +++ b/jax_triton/pallas/__init__.py @@ -14,6 +14,7 @@ """Module for pallas, a jaxpr "dialect" for Triton.""" from jax_triton.pallas.core import BlockSpec +from jax_triton.pallas.core import KernelConfig from jax_triton.pallas.pallas_call import pallas_call from jax_triton.pallas.pallas_call import pallas_call_p from jax_triton.pallas.primitives import atomic_add diff --git a/jax_triton/pallas/core.py b/jax_triton/pallas/core.py index cd165c5c..465b3d10 100644 --- a/jax_triton/pallas/core.py +++ b/jax_triton/pallas/core.py @@ -15,10 +15,29 @@ """Module for pallas-core functionality.""" import contextlib import dataclasses +import functools +from functools import partial -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union +import jax.numpy as jnp +from jax._src import api_util from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src import state +from jax._src import tree_util +from jax._src.lax.control_flow import for_loop +from jax.interpreters import partial_eval as pe +from jax._src.util import weakref_lru_cache, safe_map, safe_zip +from jax._src.state.types import AbstractRef + +from jax_triton.pallas import tracing_utils + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Grid = tuple[int, ...] +GridOrLambda = Union[Callable[..., Grid], Grid] @dataclasses.dataclass class GridEnv: @@ -73,3 +92,55 @@ class GridSpec: mapped_dims: Tuple[int, ...] replace = dataclasses.replace + +Platform = str + +@dataclasses.dataclass +class KernelConfig: + name: Optional[str] = None + in_specs: Optional[Sequence[Optional[BlockSpec]]] = None + out_specs: Optional[Sequence[Optional[BlockSpec]]] = None + grid: Optional[Union[Grid, int]] = None + meta: dict[str, Any] = dataclasses.field(default_factory=dict) + compiler_params: dict[Platform, dict[str, Any]] = dataclasses.field(default_factory=dict) + + def replace(self, *args, **kwargs): + return dataclasses.replace(self, *args, **kwargs) + +def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid: + if grid is None: + return () + if isinstance(grid, int): + return (grid,) + return grid + +def extract_function_name(f: Callable, name: Optional[str]) -> str: + if name is None: + name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func" + return name + +def convert_block_spec_to_block_mapping( + grid: Grid, block_spec: Optional[BlockSpec]) -> Optional[BlockMapping]: + if block_spec is None: + return None + in_avals = [jax_core.ShapedArray((), jnp.int32) for _ in grid] + block_shape = tuple( + mapped if s is None else s for s in block_spec.block_shape) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(block_spec.compute_index), in_avals) + return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts)) + +def compute_shape_from_block_spec(block_spec: Optional[BlockSpec], + arg_shape: tuple[int, ...] 
+ ) -> tuple[int, ...]: + if block_spec is None: + return arg_shape + return tuple(s for s in block_spec.block_shape if s is not None) + +@dataclasses.dataclass +class SpecializedKernel: + name: str + jaxpr: jax_core.Jaxpr + num_consts: int + grid_spec: GridSpec + compiler_params: dict[str, Any] diff --git a/jax_triton/pallas/ops/matmul.py b/jax_triton/pallas/ops/matmul.py new file mode 100644 index 00000000..a9c7b89d --- /dev/null +++ b/jax_triton/pallas/ops/matmul.py @@ -0,0 +1,80 @@ +# Copyright 2023 The jax_triton Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module containing fused matmuls.""" +import functools + +from typing import Any, Optional + +import jax +import jax.numpy as jnp +from jax._src.lax.control_flow.for_loop import for_loop + +import jax_triton as jt +from jax_triton import pallas as pl + +def _compute_bound_configs(): + yield from [ + dict(bm=128, bn=256, bk=32, compiler_params=dict(triton=dict(num_stages=3, num_warps=8))), + dict(bm=256, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=3, num_warps=8))), + dict(bm=256, bn=64 , bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=64, bn=256, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=128, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=128, bn=64, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=64, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=128, bn=32, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))), + dict(bm=64, bn=32, bk=32, compiler_params=dict(triton=dict(num_stages=5, num_warps=2))), + ] + +@functools.partial(jax.jit, static_argnames=["interpret", "debug"]) +def matmul(x, y, *, interpret=False, debug=False): + # TODO(sharadmv): make this implementation better + # 1. reordered programs for better L2 + # 2. split K + # 3. 
masking + m, n, k = x.shape[0], y.shape[1], x.shape[1] + + @functools.partial( + pl.pallas_call, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + interpret=interpret, + debug=debug, + autotuning_configs=[ + pl.KernelConfig( + name=f"bm={config['bm']}_bn={config['bn']}_bk={config['bk']}", + meta=dict(bk=config["bk"]), + in_specs=[ + pl.BlockSpec(lambda i, _: (i, 0), (config["bm"], x.shape[1])), + pl.BlockSpec(lambda _, j: (0, j), (y.shape[0], config["bn"])) + ], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (config["bm"], + config["bn"])), + grid=(jt.cdiv(m, config["bm"]), jt.cdiv(n, config["bn"])), + compiler_params=config["compiler_params"], + ) + for config in _compute_bound_configs() + ] + ) + def matmul_kernel(x_ref, y_ref, o_ref, *, bk: int): + acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) + def body(i, acc_ref): + x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) + y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) + acc_ref[:, :] += jnp.dot(x_block, y_block) + acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) + o_ref[:, :] = acc + return matmul_kernel(x, y) + +@jax.jit +def matmul_reference(x, y): + return jnp.dot(x, y) diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index b5a5e56e..23396b48 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -13,6 +13,7 @@ # limitations under the License. """Module for calling pallas functions from JAX.""" +import dataclasses from functools import partial import itertools as it @@ -20,7 +21,6 @@ import jax from jax import api_util -from jax import linear_util as lu from jax import tree_util from jax import lax from jax.interpreters import ad @@ -29,17 +29,17 @@ from jax.interpreters import xla from jax._src import ad_util from jax._src import core as jax_core +from jax._src import linear_util as lu from jax._src.lib.mlir.dialects import mhlo from jax._src import state from jax._src.util import ( - split_list, safe_map, safe_zip, weakref_lru_cache, + split_list, safe_map, safe_zip, tuple_insert, partition_list) -from jax._src.lax.control_flow import for_loop import jax.numpy as jnp import numpy as np -from jax_triton.utils import avals_to_layouts, normalize_grid from jax_triton.pallas import core as pallas_core +from jax_triton.pallas import tracing_utils map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -48,6 +48,7 @@ BlockSpec = pallas_core.BlockSpec BlockMapping = pallas_core.BlockMapping GridSpec = pallas_core.GridSpec +SpecializedKernel = pallas_core.SpecializedKernel pallas_call_p = jax_core.Primitive('pallas_call') pallas_call_p.multiple_results = True @@ -76,16 +77,17 @@ def _maybe_dynamic_update_slice(start_idx, block_shape, value, update, assert update.shape == block_shape return lax.dynamic_update_slice(value, update, start_idx) -def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, - interpret, debug: bool, +def _pallas_call_impl(*args, kernels: Sequence[SpecializedKernel], + name: str, out_shapes, which_linear: Sequence[bool], + interpret: bool, debug: bool, in_shapes, - input_output_aliases: Tuple[Tuple[int, int], ...], - grid_spec: GridSpec, - **compiler_params: Any): + input_output_aliases: Tuple[Tuple[int, int], ...]): if interpret: # If we're in interpreter mode, we *scan* over the grid and eval the # discharged jaxpr. This should reproduce exactly what compiling to Triton # will do. 
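# --- Illustrative sketch; not part of the diff itself -----------------------
# With `interpret=True`, only the first specialized kernel (kernels[0]) is
# evaluated, so autotuned calls can still be checked without a GPU. This
# mirrors `test_basic_autotuning` added in tests/pallas_test.py; the kernel
# below is a minimal example, not the library's own code.
import functools
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    autotuning_configs=[
        pl.KernelConfig(meta=dict(block_size=bs), grid=8 // bs)
        for bs in (1, 2, 4, 8)
    ],
    interpret=True)
def add_one(x_ref, o_ref, *, block_size):
  # Each config re-traces the kernel with a different `block_size` via `meta`.
  idx = pl.program_id(0) * block_size + jnp.arange(block_size)
  o_ref[idx] = x_ref[idx] + 1.

x = jnp.arange(8.)
assert jnp.allclose(add_one(x), x + 1.)
# -----------------------------------------------------------------------------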
+ kernel = kernels[0] + grid_spec, jaxpr = kernel.grid_spec, kernel.jaxpr grid = grid_spec.grid discharged_jaxpr, consts = state.discharge_state(jaxpr, ()) if debug: @@ -132,17 +134,23 @@ def body(carry): out_shapes=out_shapes, which_linear=which_linear, grid_spec=grid_spec, interpret=interpret, debug=debug, - input_output_aliases=input_output_aliases, - **compiler_params) + input_output_aliases=input_output_aliases) pallas_call_p.def_impl(_pallas_call_impl) def _pallas_call_abstract_eval(*avals, out_shapes, **_): return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) -def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, +def _pallas_call_jvp_rule( + primals, tangents, *, + kernels: Sequence[SpecializedKernel], + name: str, + which_linear: Sequence[bool], input_output_aliases: Tuple[Tuple[int, int], ...], - in_shapes, out_shapes, grid_spec, debug, interpret, **compiler_params: Any): + in_shapes: Sequence[jax.ShapeDtypeStruct], + out_shapes: Sequence[jax.ShapeDtypeStruct], + debug: bool, + interpret: bool): if input_output_aliases: raise NotImplementedError("JVP with aliasing not supported.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] @@ -150,44 +158,48 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, for t, inst in zip(tangents, nonzero_tangents)] tangents = [t for t in tangents if type(t) is not ad_util.Zero] nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes) - closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) - jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) - jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts - jvp_which_linear = which_linear + (True,) * len(tangents) - jvp_inshapes = (*in_shapes, *in_shapes) - jvp_outshapes = (*out_shapes, *out_shapes) - if input_output_aliases: - raise NotImplementedError("`input_output_aliases` jvp not supported.") - # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*. - # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input - # `Ref`s that are read from followed by output `Ref`s that are written to. - # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new - # jaxpr that has tangents following primals. In order for this jaxpr to be - # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around - # the jaxpr's invars. 
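# --- Illustrative sketch; not part of the diff itself -----------------------
# The JVP rule is reworked below to derive a tangent jaxpr for every
# specialized kernel, so differentiating through a `pallas_call` is expected to
# keep working when autotuning configs are supplied. A hedged sketch with a
# single trivial kernel (interpret mode, so it runs without a GPU); the kernel
# and names here are assumptions for illustration only.
import functools
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    grid=1,
    interpret=True)
def square(x_ref, o_ref):
  o_ref[...] = x_ref[...] ** 2

x = jnp.arange(8.)
primals_out, tangents_out = jax.jvp(square, (x,), (jnp.ones_like(x),))
# For this kernel the tangent output should equal 2 * x.
# -----------------------------------------------------------------------------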
- logical_primals, logical_tangents = split_list( - jvp_jaxpr.invars, [len(primals) + len(out_shapes)]) - logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)]) - logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)]) - in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)]) - new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms)) - new_grid_spec = grid_spec.replace(block_mappings=new_bms) - jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs, - *logical_tangent_inputs, - *logical_primal_outputs, - *logical_tangent_outputs]) - if debug: - print(jvp_jaxpr) - out_flat = pallas_call_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr, + new_kernels = [] + for kernel in kernels: + jaxpr, grid_spec = kernel.jaxpr, kernel.grid_spec + closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) + jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) + jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts + jvp_which_linear = which_linear + (True,) * len(tangents) + jvp_inshapes = (*in_shapes, *in_shapes) + jvp_outshapes = (*out_shapes, *out_shapes) + if input_output_aliases: + raise NotImplementedError("`input_output_aliases` jvp not supported.") + # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*. + # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input + # `Ref`s that are read from followed by output `Ref`s that are written to. + # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new + # jaxpr that has tangents following primals. In order for this jaxpr to be + # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around + # the jaxpr's invars. + logical_primals, logical_tangents = split_list( + jvp_jaxpr.invars, [len(primals) + len(out_shapes)]) + logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)]) + logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)]) + in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)]) + new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms)) + new_grid_spec = grid_spec.replace(block_mappings=new_bms) + jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs, + *logical_tangent_inputs, + *logical_primal_outputs, + *logical_tangent_outputs]) + if debug: + print(jvp_jaxpr) + new_kernels.append(SpecializedKernel(kernel.name, jvp_jaxpr, + kernel.num_consts, new_grid_spec, + kernel.compiler_params)) + out_flat = pallas_call_p.bind(*primals, *tangents, kernels=new_kernels, name=f"{name}_jvp", in_shapes=jvp_inshapes, out_shapes=jvp_outshapes, - grid_spec=new_grid_spec, which_linear=jvp_which_linear, interpret=interpret, debug=debug, - input_output_aliases=(), - **compiler_params) + input_output_aliases=()) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule @@ -217,194 +229,232 @@ def _block_map_function(new_idx, *args): jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)) def _pallas_call_batching_rule(args, dims, *, - jaxpr: jax_core.Jaxpr, + kernels: Sequence[SpecializedKernel], name: str, in_shapes: Tuple[jax.ShapeDtypeStruct, ...], out_shapes: Tuple[jax.ShapeDtypeStruct, ...], - grid_spec: GridSpec, input_output_aliases: Tuple[Tuple[int, int], ...], debug: bool, interpret: bool, - which_linear: Tuple[bool, ...], - **compiler_params: Any): + which_linear: Tuple[bool, 
...]): axis_size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped} - block_mappings = grid_spec.block_mappings - avals = [v.aval for v in jaxpr.invars] - # How should we pick output dimensions? This actually matters because XLA - # can't optimize our pallas kernels, and this layout impacts performance. For - # now, because `vmap` doesn't really offer a way of inferring good output - # dimensions. For now, we just use 0. - # TODO(sharadmv): explore inferring better output dimensions via a heuristic - # TODO(sharadmv): explore a long term solution to output dim inference + new_kernels = [] + for kernel in kernels: + jaxpr, grid_spec = kernel.jaxpr, kernel.grid_spec + block_mappings = grid_spec.block_mappings + avals = [v.aval for v in jaxpr.invars] + # How should we pick output dimensions? This actually matters because XLA + # can't optimize our pallas kernels, and this layout impacts performance. For + # now, because `vmap` doesn't really offer a way of inferring good output + # dimensions. For now, we just use 0. + # TODO(sharadmv): explore inferring better output dimensions via a heuristic + # TODO(sharadmv): explore a long term solution to output dim inference - # When we have input/output aliasing, since the output will be mapped, we need - # to make sure to broadcast the input across that dimension if it is not - # mapped. - dims_ = list(dims) - args_ = list(args) - for input_index, _ in input_output_aliases: - dim = dims_[input_index] - if dim is batching.not_mapped: - dims_[input_index] = 0 - args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) - args = tuple(args_) - dims = tuple(dims_) + # When we have input/output aliasing, since the output will be mapped, we need + # to make sure to broadcast the input across that dimension if it is not + # mapped. 
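# --- Illustrative sketch; not part of the diff itself -----------------------
# The batching rule prepends the batch dimension to every specialized kernel's
# grid, so `jax.vmap` composes with autotuned kernels. This mirrors
# `test_vmap_of_autotuned_kernel` added in tests/pallas_test.py.
import functools
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    autotuning_configs=[
        pl.KernelConfig(
            meta=dict(block_size=bs),
            in_specs=[pl.BlockSpec(lambda i: i, (bs,)),
                      pl.BlockSpec(lambda i: i, (bs,))],
            out_specs=[pl.BlockSpec(lambda i: i, (bs,))],
            grid=8 // bs)
        for bs in (4, 8)
    ],
    interpret=True)  # set to False to exercise the Triton path on GPU
def add(x_ref, y_ref, o_ref, *, block_size):
  del block_size  # only used to pick the block shape in the specs above
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(16.).reshape((2, 8))
out = jax.vmap(add)(x, x)  # expected: 2 * x
# -----------------------------------------------------------------------------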
+ dims_ = list(dims) + args_ = list(args) + for input_index, _ in input_output_aliases: + dim = dims_[input_index] + if dim is batching.not_mapped: + dims_[input_index] = 0 + args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + args = tuple(args_) + dims = tuple(dims_) - all_dims = list(dims) + [0] * len(out_shapes) + all_dims = list(dims) + [0] * len(out_shapes) - batched_block_mappings = map(partial(_batch_block_mapping, grid_spec.grid), - avals, all_dims, block_mappings) - batched_in_shapes = tuple( - jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else - tuple_insert(x.shape, dim, axis_size), - x.dtype) - for x, dim in zip(in_shapes, dims)) - batched_out_shapes = tuple( - jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype) - for x in out_shapes) + batched_block_mappings = map(partial(_batch_block_mapping, grid_spec.grid), + avals, all_dims, block_mappings) + batched_in_shapes = tuple( + jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else + tuple_insert(x.shape, dim, axis_size), + x.dtype) + for x, dim in zip(in_shapes, dims)) + batched_out_shapes = tuple( + jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype) + for x in out_shapes) - batched_grid_spec = grid_spec.replace(grid=(axis_size, *grid_spec.grid), - block_mappings=tuple(batched_block_mappings), - mapped_dims=(0,) + tuple(a + 1 for a in + batched_grid_spec = grid_spec.replace(grid=(axis_size, *grid_spec.grid), + block_mappings=tuple(batched_block_mappings), + mapped_dims=(0,) + tuple(a + 1 for a in grid_spec.mapped_dims)) - out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}", + new_kernel = SpecializedKernel(kernel.name, jaxpr, kernel.num_consts, + batched_grid_spec, kernel.compiler_params) + new_kernels.append(new_kernel) + + out = pallas_call_p.bind(*args, kernels=new_kernels, name=f"batched_{name}", in_shapes=batched_in_shapes, out_shapes=batched_out_shapes, which_linear=which_linear, - grid_spec=batched_grid_spec, input_output_aliases=input_output_aliases, debug=debug, - interpret=interpret, - **compiler_params) + interpret=interpret) return out, (0,) * len(out) batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule -@weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): - wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(fun), in_tree) - debug = pe.debug_info(fun, in_tree, out_tree_thunk, False, - primitive_name or "") - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug) - jaxpr = for_loop._hoist_consts_to_refs(jaxpr) - return jaxpr, consts, out_tree_thunk() - -def _preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid: - if grid is None: - return () - if isinstance(grid, int): - return (grid,) - return grid - -def _extract_function_name(f: Callable, name: Optional[str]) -> str: - if name is None: - name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func" - return name +Kernel = Callable[..., Any] +MaybeSpec = Optional[Union[pallas_core.BlockSpec, + Callable[..., pallas_core.BlockSpec]]] +Platform = str +CompilerParams = dict[Platform, dict[str, Any]] -def _convert_block_spec_to_block_mapping( - grid: Grid, block_spec: Optional[BlockSpec]) -> Optional[BlockMapping]: - if block_spec is None: - return None - in_avals = [jax_core.ShapedArray((), jnp.int32) for _ in grid] - block_shape = tuple( - pallas_core.mapped if s is None else s for s in block_spec.block_shape) - jaxpr, _, consts 
= pe.trace_to_jaxpr_dynamic( - lu.wrap_init(block_spec.compute_index), in_avals) - return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts)) +def specialize_kernel(config: pallas_core.KernelConfig, + func: Callable, + name: str, + in_avals: tuple[jax_core.ShapedArray, ...], + out_avals: tuple[jax_core.ShapedArray, ...], + in_tree: tree_util.PyTreeDef, + compiler_params: CompilerParams, + ) -> tuple[SpecializedKernel, ...]: + grid = config.grid + if grid == (): + in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) + for arg in in_avals] + out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) + for arg in out_avals] + else: + in_ref_avals = [ + state.shaped_array_ref( + pallas_core.compute_shape_from_block_spec(block_spec, aval.shape), + aval.dtype) + for block_spec, aval in zip(config.in_specs, in_avals)] + out_ref_avals = [ + state.shaped_array_ref( + pallas_core.compute_shape_from_block_spec(block_spec, aval.shape), + aval.dtype) + for block_spec, aval in zip(config.out_specs, out_avals)] + in_block_mappings = map( + partial(pallas_core.convert_block_spec_to_block_mapping, grid), + config.in_specs) + out_block_mappings = map( + partial(pallas_core.convert_block_spec_to_block_mapping, grid), + config.out_specs) + grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ()) + jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr( + func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta) + if config.name is not None: + name = f"{name}_{config.name}" + return SpecializedKernel(name, jaxpr, len(consts), grid_spec, + dict(compiler_params, **config.compiler_params)), consts, out_tree -def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec], - arg_shape: Tuple[int, ...] 
- ) -> Tuple[int, ...]: - if block_spec is None: - return arg_shape - return tuple(s for s in block_spec.block_shape if s is not None) +def _canonicalize_kernel_config( + maybe_kernel_config: Optional[pallas_core.KernelConfig], + in_avals: Sequence[jax_core.AbstractValue], + out_avals: Sequence[jax_core.AbstractValue], + in_specs: Optional[Sequence[Optional[BlockSpec]]], + out_specs: Optional[Sequence[Optional[BlockSpec]]], + grid: Optional[Union[Grid, int]], + compiler_params: dict[str, Any], + ) -> pallas_core.KernelConfig: + if not maybe_kernel_config: + config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, + grid=grid, + compiler_params=compiler_params) + else: + config = maybe_kernel_config + config = config.replace(compiler_params=dict(compiler_params, + **config.compiler_params)) + grid = maybe_kernel_config.grid + grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs + grid = pallas_core.preprocess_grid(grid) + if in_specs is not None and not isinstance(in_specs, (tuple, list)): + in_specs = (in_specs,) + if out_specs is not None and not isinstance(out_specs, (tuple, list)): + out_specs = (out_specs,) + if in_specs is None: + in_specs = [None] * len(in_avals) + if out_specs is None: + out_specs = [None] * len(out_avals) + return config.replace(grid=grid, in_specs=in_specs, out_specs=out_specs) -def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, +def pallas_call(f: Callable, out_shape: Any, *, grid: Optional[Grid] = None, + config: Optional[pallas_core.KernelConfig] = None, in_specs: Optional[Sequence[Optional[BlockSpec]]] = None, out_specs: Optional[Sequence[Optional[BlockSpec]]] = None, input_output_aliases: Dict[int, int] = {}, interpret: bool = False, name: Optional[str] = None, - **compiler_params: Any): - if grid is None: - if in_specs is not None: - raise ValueError("Cannot specify `in_specs` with a `None` grid.") - if out_specs is not None: - raise ValueError("Cannot specify `out_specs` with a `None` grid.") - grid = _preprocess_grid(grid) - name = _extract_function_name(f, name) + autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None, + debug: bool = False, + compiler_params: Optional[CompilerParams] = None): + if config is not None: + if grid is not None or in_specs is not None or out_specs is not None: + raise ValueError("Cannot specify both `config` and any of `grid`, " + "`in_specs`, or `out_specs`.") + if autotuning_configs is not None: + raise ValueError("Cannot specify both `config` and `autotuning_configs`") + if autotuning_configs is not None: + if grid is not None or in_specs is not None or out_specs is not None: + raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, " + "`in_specs`, or `out_specs`.") + if compiler_params is None: + compiler_params = {} singleton = False if not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) singleton = True if not isinstance(out_shape, tuple): out_shape = tuple(out_shape) - flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape) - if out_specs is not None and not isinstance(out_specs, (tuple, list)): - out_specs = (out_specs,) - if out_specs is not None and not isinstance(out_specs, tuple): - out_specs = tuple(out_specs) - flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) - for x in flat_out_shapes] + if not name: + name = f.__name__ if hasattr(f, "__name__") else "unnamed" + @jax.jit def wrapped(*args): flat_args, in_tree = tree_util.tree_flatten(args) - if grid is None: - flat_in_specs = [None] * 
len(flat_args) - flat_out_specs = [None] * len(flat_out_shapes) - in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) - for arg in flat_args] - out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) - for arg in flat_out_shapes] - else: - if in_specs is None: - flat_in_specs = [None for arg in flat_args] - else: - flat_in_specs, in_block_tree = tree_util.tree_flatten(tuple(in_specs)) - if in_block_tree != in_tree: - raise ValueError( - "Pytree specs for arguments and `in_specs` must match: " - f"{in_tree} vs. {in_block_tree}") - if out_specs is None: - flat_out_specs = [None for arg in flat_out_shapes] - else: - flat_out_specs, out_block_tree = tree_util.tree_flatten(out_specs) - if out_block_tree != out_tree: - raise ValueError("Pytree specs for `out_shape` and `out_specs` must match: " - f"{out_tree} vs. {out_block_tree}") - in_ref_avals = [ - state.shaped_array_ref( - _compute_shape_from_block_spec(block_spec, arg.shape), arg.dtype) - for block_spec, arg in zip(flat_in_specs, flat_args)] - out_ref_avals = [ - state.shaped_array_ref( - _compute_shape_from_block_spec(block_spec, arg.shape), arg.dtype) - for block_spec, arg in zip(flat_out_specs, flat_out_shapes)] - in_block_mappings = map(partial(_convert_block_spec_to_block_mapping, grid), - flat_in_specs) - out_block_mappings = map(partial(_convert_block_spec_to_block_mapping, grid), - flat_out_specs) + flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape) + jaxpr_in_tree = tree_util.tree_structure((*args, *out_shape)) - jaxpr, consts, _ = _initial_style_open_jaxpr( - f, jaxpr_in_tree, tuple((*in_ref_avals, *out_ref_avals)), - primitive_name="pallas_call") - flat_in_specs = it.chain([None] * len(consts), flat_in_specs) - grid_spec = GridSpec(grid, tuple((*in_block_mappings, - *out_block_mappings)), - ()) + flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) + for a in flat_args) + flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype) + for a in flat_out_shapes) + canonicalized_configs = [] + if autotuning_configs is None: + canonicalized_configs.append(_canonicalize_kernel_config(config, + flat_in_avals, + flat_out_avals, + in_specs, + out_specs, + grid, + compiler_params)) + else: + canonicalized_configs.extend(map(partial(_canonicalize_kernel_config, + in_avals=flat_in_avals, + out_avals=flat_out_avals, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + compiler_params=compiler_params), + autotuning_configs)) + kernels = [] + all_consts = [] + if len(canonicalized_configs) == 0: + raise ValueError("Cannot pass in empty autotuning configs") + for canonicalized_config in canonicalized_configs: + specialized_kernel, consts, jaxpr_out_tree = specialize_kernel( + canonicalized_config, f, name, flat_in_avals, + flat_out_avals, jaxpr_in_tree, compiler_params) + kernels.append(specialized_kernel) + all_consts.extend(consts) + if all_consts: + raise NotImplementedError("Cannot handle consts.") + del jaxpr_out_tree which_linear = (False,) * len(flat_args) out_flat = pallas_call_p.bind( - *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear, + *all_consts, *flat_args, + kernels=kernels, + name=name, + which_linear=which_linear, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in flat_args), - out_shapes=tuple(flat_out_shapes), debug=debug, + out_shapes=tuple(flat_out_shapes), + debug=debug, interpret=interpret, - grid_spec=grid_spec, - input_output_aliases=tuple(input_output_aliases.items()), - **compiler_params) + 
input_output_aliases=tuple(input_output_aliases.items())) out = tree_util.tree_unflatten(out_tree, out_flat) if singleton: return out[0] diff --git a/jax_triton/pallas/tracing_utils.py b/jax_triton/pallas/tracing_utils.py new file mode 100644 index 00000000..4ec68658 --- /dev/null +++ b/jax_triton/pallas/tracing_utils.py @@ -0,0 +1,44 @@ +# Copyright 2022 The jax_triton Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for tracing utilities.""" + +from typing import Any, Callable, Optional + +from jax.interpreters import partial_eval as pe +from jax._src import api_util +from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src.util import weakref_lru_cache, safe_map, safe_zip, HashablePartial +from jax._src.lax.control_flow import for_loop + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +@weakref_lru_cache +def initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, primitive_name, *args, **kwargs): + wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(HashablePartial(fun, *args, **kwargs)), in_tree) + debug_info = pe.debug_info(fun, in_tree, out_tree_thunk, False, + primitive_name or "") + jaxpr, consts = initial_style_flat_jaxpr(wrapped_fun, in_avals, + debug_info=debug_info) + return jaxpr, consts, out_tree_thunk() + +def initial_style_flat_jaxpr(fun: lu.WrappedFun, in_avals, + debug_info: Optional[jax_core.JaxprDebugInfo] = None + ) -> tuple[jax_core.Jaxpr, list[Any]]: + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals, debug_info) + jaxpr = for_loop._hoist_consts_to_refs(jaxpr) + return jaxpr, consts + diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index b3973372..96f2e1a2 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -44,8 +44,9 @@ import triton import triton.language as tl from triton.language import ir as tl_ir -import triton.libtriton.triton as _triton +import triton._C.libtriton.triton as _triton +from jax_triton import triton_lib from jax_triton import triton_kernel_call_lib from jax_triton import utils as triton_utils from jax_triton.pallas import primitives @@ -136,7 +137,7 @@ def _process_grid_to_3d_grid(builder, grid_spec: GridSpec): assert len(out_indices) == len(grid_spec.grid) return new_grid, out_indices -def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, +def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, num_consts: int, in_shapes, grid_spec: GridSpec, name: str) -> tl_ir.module: jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) ir_context = tl_ir.context() @@ -167,11 +168,13 @@ def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: Gr local_program_ids) start_indices = map(partial(_eval_index_map, ctx, program_ids), grid_spec.block_mappings) - block_infos = [BlockInfo(shape, - start_idx, block_mapping.block_shape) - if block_mapping is not None else None - for shape, 
block_mapping, start_idx in - zip(in_shapes, grid_spec.block_mappings, start_indices)] + arg_block_infos = [BlockInfo(shape, + start_idx, block_mapping.block_shape) + if block_mapping is not None else None + for shape, block_mapping, start_idx in + zip(in_shapes, grid_spec.block_mappings, start_indices)] + const_block_infos = [None] * num_consts + block_infos = [*const_block_infos, *arg_block_infos] () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args) ctx.builder.ret_void() ctx.builder.set_insert_block(insert_pt) @@ -797,22 +800,18 @@ def _while_lowering_rule(ctx: TritonLoweringRuleContext, *args, cond_nconsts, return post_args triton_lowering_rules[lax.while_p] = _while_lowering_rule -@weakref_lru_cache -def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, - name: str, num_warps: int, num_stages: int - ) -> TritonCompilationResult: - lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name) - backend = _triton.runtime.backend.CUDA - device = 0 - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, lowering_result.module, device, - num_warps, num_stages, {}, 0) - return TritonCompilationResult(name, asm, shared_mem, lowering_result) +def _mangle_name(name: str): + name = name.replace("-", "_") + name = name.replace("-", "_") + name = name.replace("=", "_") + return name @weakref_lru_cache -def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, +def compile_jaxpr(jaxpr: jax_core.Jaxpr, num_consts: int, in_shapes, grid_spec: GridSpec, name: str, num_warps: int, num_stages: int ) -> TritonCompilationResult: - lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name) + name = _mangle_name(name) + lowering_result = lower_jaxpr_to_triton_module(jaxpr, num_consts, in_shapes, grid_spec, name) backend = _triton.runtime.backend.CUDA device = 0 name, asm, shared_mem = _triton.code_gen.compile_ttir( @@ -821,59 +820,100 @@ def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, - jaxpr: jax_core.Jaxpr, + kernels: Sequence[pallas_core.SpecializedKernel], name: str, in_shapes: Tuple[jax.ShapeDtypeStruct, ...], out_shapes: Tuple[jax.ShapeDtypeStruct, ...], which_linear: Tuple[bool, ...], interpret: bool, debug: bool, - input_output_aliases: Tuple[Tuple[int, int], ...], - grid_spec: GridSpec, - **compiler_params: Any): + input_output_aliases: Tuple[Tuple[int, int], ...]): if interpret: return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( - ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, + ctx, *in_nodes, kernels=kernels, name=name, out_shapes=out_shapes, in_shapes=in_shapes, which_linear=which_linear, interpret=interpret, debug=debug, - input_output_aliases=input_output_aliases, - grid_spec=grid_spec, **compiler_params) - num_warps = compiler_params.get("num_warps", 4) - num_stages = compiler_params.get("num_stages", 3) - compilation_result = compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)), - grid_spec, name, num_warps, num_stages) - name = compilation_result.name - asm = compilation_result.asm - shared_mem = compilation_result.shared_mem - if debug: - print(jaxpr) - print(grid_spec) - lowering_result = compilation_result.lowering_result - if debug: - lowering_result.module.print() + input_output_aliases=input_output_aliases) + lowered_kernels = [] + for kernel in kernels: + if debug: + print(kernel.jaxpr) + print(kernel.grid_spec) + print(kernel.compiler_params) + compiler_params = 
dict(kernel.compiler_params).get("triton", {}) + num_warps = compiler_params.pop("num_warps", 4) + num_stages = compiler_params.pop("num_stages", 3) + if compiler_params: + raise ValueError(f"Invalid compiler params: {compiler_params}") + compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts, + tuple((*in_shapes, *out_shapes)), + kernel.grid_spec, kernel.name, num_warps, + num_stages) + name = compilation_result.name + asm = compilation_result.asm + shared_mem = compilation_result.shared_mem + lowering_result = compilation_result.lowering_result + if debug: + lowering_result.module.print() + lowered_kernels.append((name, asm, shared_mem, lowering_result, num_warps)) out_type = ir.TupleType.get_tuple([ ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype)) for out_shape in ctx.avals_out]) i32_type = ir.IntegerType.get_signless(32) - kernel = triton_kernel_call_lib.TritonKernel( - asm["cubin"], name, num_warps, shared_mem - ) - - grid = triton_utils.normalize_grid( - compilation_result.lowering_result.grid, metaparams={}) - kernel_params = [] - for _ in range(len(in_shapes) + len(out_shapes)): - kernel_params.append( - triton_kernel_call_lib.create_array_parameter( - 0, # bytes to zero # TODO(cjfj): Expose through user API. - True, # divisible by 16 + if len(lowered_kernels) == 1: + name, asm, shared_mem, lowering_result, num_warps = lowered_kernels[0] + kernel = triton_kernel_call_lib.TritonKernel( + asm["cubin"], name, num_warps, shared_mem + ) + grid = triton_utils.normalize_grid( + lowering_result.grid, metaparams={}) + # All arguments are buffers. + kernel_params = [] + for _ in range(len(in_shapes) + len(out_shapes)): + kernel_params.append( + triton_kernel_call_lib.create_array_parameter( + 0, # bytes to zero # TODO(cjfj): Expose through user API. + True, # divisible by 16 + ) + ) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, grid[0], grid[1], grid[2], kernel_params + ) + elif len(lowered_kernels) > 1: + kernel_calls = [] + for name, asm, shared_mem, lowering_result, num_warps in lowered_kernels: + kernel = triton_kernel_call_lib.TritonKernel( + asm["cubin"], name, num_warps, shared_mem + ) + grid = triton_utils.normalize_grid( + lowering_result.grid, metaparams={}) + # All arguments are buffers. + kernel_params = [] + for _ in range(len(in_shapes) + len(out_shapes)): + kernel_params.append( + triton_kernel_call_lib.create_array_parameter( + 0, # bytes to zero # TODO(cjfj): Expose through user API. 
+ True, # divisible by 16 + ) ) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, grid[0], grid[1], grid[2], kernel_params + ) + kernel_calls.append(kernel_call) + input_output_aliases_with_sizes = tuple( + (input_idx, output_idx, triton_lib.aval_size_bytes(ctx.avals_in[input_idx])) + for input_idx, output_idx in input_output_aliases ) - kernel_call = triton_kernel_call_lib.TritonKernelCall( - kernel, grid[0], grid[1], grid[2], kernel_params - ) + + kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall( + name, + [(call, f"config{i}") for i, call in enumerate(kernel_calls)], + input_output_aliases_with_sizes, + ) + else: + raise ValueError("Cannot have 0 kernels.") ctx.module_context.add_keepalive(kernel_call) output_operand_aliases = ir.ArrayAttr.get([ diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 1f2a2b76..dd7cd2e1 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -33,9 +33,10 @@ import jax.numpy as jnp import jax_triton as jt from jax_triton import pallas as pl -from jax_triton.pallas.pallas_call import _initial_style_open_jaxpr from jax_triton.pallas.ops import attention from jax_triton.pallas.ops import layer_norm +from jax_triton.pallas.ops import matmul as matmul_ops +from jax_triton.pallas.tracing_utils import initial_style_open_jaxpr try: from jax_triton.pallas.triton_ir_lowering import compile_jaxpr except ModuleNotFoundError: @@ -123,7 +124,7 @@ def setUp(self): super().setUp() if compile_jaxpr: compile_jaxpr.cache_clear() - _initial_style_open_jaxpr.cache_clear() + initial_style_open_jaxpr.cache_clear() def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) @@ -851,6 +852,95 @@ def body(x_ref): lu.wrap_init(body), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) +class PallasAutotuningTest(PallasTest): + + def test_basic_autotuning(self): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + autotuning_configs=[ + pl.KernelConfig(meta=dict(block_size=block_size), + grid=8 // block_size) + for block_size in [1, 2, 4, 8] + ]) + def add_one(x_ref, o_ref, *, block_size): + idx = pl.program_id(0) * block_size + jnp.arange(block_size) + o_ref[idx] = x_ref[idx] + 1. + + x = jnp.arange(8.) + np.testing.assert_allclose(add_one(x), x + 1.) + + def test_basic_autotuning_with_block_spec(self): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + autotuning_configs=[ + pl.KernelConfig(name=f"block_size={block_size}_num_warps={num_warps}_num_stages={num_stages}", + meta=dict(block_size=block_size), + in_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)) + ], + out_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)) + ], + grid=8 // block_size, + compiler_params=dict(triton=dict(num_warps=num_warps, + num_stages=num_stages))) + for block_size in [4, 8] + for num_warps, num_stages in zip([4, 4], [3, 2]) + ], + debug=True) + def add_one(x_ref, o_ref, *, block_size): + del block_size + o_ref[...] = x_ref[...] + 1. + + x = jnp.arange(8.) + np.testing.assert_allclose(add_one(x), x + 1.) 
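# --- Illustrative sketch; not part of the diff itself -----------------------
# `compiler_params` can also be passed once to `pl.pallas_call` as a default
# for every config; `_canonicalize_kernel_config` merges it with each config's
# own dict via `dict(compiler_params, **config.compiler_params)`, so a
# per-config entry for a platform key (e.g. "triton") replaces the global one
# for that config. The kernel and values below are assumptions for
# illustration.
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def add_one_kernel(x_ref, o_ref, *, block_size):
  del block_size
  o_ref[...] = x_ref[...] + 1.

add_one = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    autotuning_configs=[
        pl.KernelConfig(
            meta=dict(block_size=bs),
            in_specs=[pl.BlockSpec(lambda i: i, (bs,))],
            out_specs=[pl.BlockSpec(lambda i: i, (bs,))],
            grid=8 // bs,
            # Replaces the default "triton" entry below for this config.
            compiler_params=dict(triton=dict(num_warps=8, num_stages=3)))
        for bs in (4, 8)
    ],
    compiler_params=dict(triton=dict(num_warps=4, num_stages=3)))
# -----------------------------------------------------------------------------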
+ + @parameterized.parameters( + (256, 256, 256), + (1024, 1024, 1024), + (1024, 512, 1024), + (1024, 512, 512), + ) + def test_matmul_autotuned(self, m, n, k): + key = random.PRNGKey(0) + k1, k2 = random.split(key) + x = random.normal(k1, (m, k), dtype=jnp.float16) + y = random.normal(k2, (k, n), dtype=jnp.float16) + np.testing.assert_allclose(matmul_ops.matmul(x, y, debug=True), + matmul_ops.matmul_reference(x, y), + atol=0.05, rtol=0.05) + + + def test_vmap_of_autotuned_kernel(self): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + autotuning_configs=[ + pl.KernelConfig(name=f"block_size={block_size}_num_warps={num_warps}_num_stages={num_stages}", + meta=dict(block_size=block_size), + in_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)), + pl.BlockSpec(lambda i: i, (block_size,)) + ], + out_specs=[ + pl.BlockSpec(lambda i: i, (block_size,)) + ], + grid=8 // block_size, + compiler_params=dict(triton=dict(num_warps=num_warps, + num_stages=num_stages))) + for block_size in [4, 8] + for num_warps, num_stages in zip([4, 4], [3, 2]) + ], + debug=True) + def add(x_ref, y_ref, o_ref, *, block_size): + del block_size + o_ref[...] = x_ref[...] + y_ref[...] + + x = jnp.arange(16.).reshape((2, 8)) + np.testing.assert_allclose(jax.vmap(add)(x, x), x + x) + class FusedAttentionTest(parameterized.TestCase): @parameterized.parameters(*[