From 391cbc17f43ae37a97377ae95b052ef75a01b4e7 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Thu, 30 Mar 2023 20:41:36 +0000
Subject: [PATCH] Remove lambdas

---
 jax_triton/pallas/__init__.py           |   2 +-
 jax_triton/pallas/core.py               |  14 ++-
 jax_triton/pallas/pallas_call.py        | 113 +++++++++++++++---
 jax_triton/pallas/triton_ir_lowering.py |   8 +-
 tests/pallas_test.py                    |  26 +++---
 5 files changed, 102 insertions(+), 61 deletions(-)

diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py
index 9a7d1986..76a7e323 100644
--- a/jax_triton/pallas/__init__.py
+++ b/jax_triton/pallas/__init__.py
@@ -14,7 +14,7 @@
 """Module for pallas, a jaxpr "dialect" for Triton."""
 
 from jax_triton.pallas.core import BlockSpec
-from jax_triton.pallas.core import Config
+from jax_triton.pallas.core import KernelConfig
 from jax_triton.pallas.pallas_call import pallas_call
 from jax_triton.pallas.pallas_call import pallas_call_p
 from jax_triton.pallas.primitives import atomic_add
diff --git a/jax_triton/pallas/core.py b/jax_triton/pallas/core.py
index a1bc7e4f..0ff92d46 100644
--- a/jax_triton/pallas/core.py
+++ b/jax_triton/pallas/core.py
@@ -18,7 +18,7 @@
 import functools
 from functools import partial
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union
 
 import jax.numpy as jnp
 from jax._src import api_util
@@ -95,6 +95,18 @@ class GridSpec:
 
 Platform = str
 
+
+@dataclasses.dataclass
+class KernelConfig:
+  in_specs: Optional[Sequence[Optional[BlockSpec]]] = None
+  out_specs: Optional[Sequence[Optional[BlockSpec]]] = None
+  grid: Optional[Union[Grid, int]] = None
+  meta: dict[str, Any] = dataclasses.field(default_factory=dict)
+  compiler_params: dict[Platform, dict[str, Any]] = dataclasses.field(default_factory=dict)
+
+  def replace(self, *args, **kwargs):
+    return dataclasses.replace(self, *args, **kwargs)
+
 @dataclasses.dataclass
 class Config:
   meta: dict[str, Any]
diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py
index f8e2bc3c..837049a7 100644
--- a/jax_triton/pallas/pallas_call.py
+++ b/jax_triton/pallas/pallas_call.py
@@ -306,24 +306,16 @@ def _compute_spec(config: Config, spec: MaybeSpec,
     spec = spec(**config.meta)
   return spec
 
-def specialize_kernel(config: Config,
+def specialize_kernel(config: pallas_core.KernelConfig,
                       func: Callable,
-                      grid: Optional[pallas_core.Grid],
                       name: Optional[str],
-                      in_specs: Optional[list[Optional[BlockSpec]]],
-                      out_specs: Optional[list[Optional[BlockSpec]]],
                       in_avals: tuple[jax_core.ShapedArray, ...],
                       out_avals: tuple[jax_core.ShapedArray, ...],
                       in_tree: tree_util.PyTreeDef,
                       compiler_params: dict[str, Any]
                       ) -> tuple[SpecializedKernel, ...]:
-  specialized_grid = grid
-  if callable(specialized_grid):
-    specialized_grid = specialized_grid(**config.meta)
-  specialized_grid = pallas_core.preprocess_grid(specialized_grid)
-  specialized_in_specs = map(partial(_compute_spec, config), in_specs)
-  specialized_out_specs = map(partial(_compute_spec, config), out_specs)
-  if specialized_grid == ():
+  grid = config.grid
+  if grid == ():
     in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
                     for arg in in_avals]
     out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
@@ -333,42 +325,76 @@ def specialize_kernel(config: Config,
         state.shaped_array_ref(
           pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
           aval.dtype)
-        for block_spec, aval in zip(specialized_in_specs, in_avals)]
+        for block_spec, aval in zip(config.in_specs, in_avals)]
     out_ref_avals = [
         state.shaped_array_ref(
          pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
          aval.dtype)
-        for block_spec, aval in zip(specialized_out_specs, out_avals)]
-  in_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
-                          specialized_in_specs)
-  out_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
-                           specialized_out_specs)
-  grid_spec = pallas_core.GridSpec(specialized_grid, (*in_block_mappings, *out_block_mappings), ())
+        for block_spec, aval in zip(config.out_specs, out_avals)]
+  in_block_mappings = map(
+      partial(pallas_core.convert_block_spec_to_block_mapping, grid),
+      config.in_specs)
+  out_block_mappings = map(
+      partial(pallas_core.convert_block_spec_to_block_mapping, grid),
+      config.out_specs)
+  grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ())
   jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr(
       func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta)
   return SpecializedKernel("foo", jaxpr, len(consts), grid_spec,
                            dict(compiler_params, **config.compiler_params)), consts, out_tree
 
-def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
+def _canonicalize_kernel_config(
+    maybe_kernel_config: Optional[pallas_core.KernelConfig],
+    in_avals: Sequence[jax_core.AbstractValue],
+    out_avals: Sequence[jax_core.AbstractValue],
+    in_specs: Optional[Sequence[Optional[BlockSpec]]],
+    out_specs: Optional[Sequence[Optional[BlockSpec]]],
+    grid: Optional[Union[Grid, int]],
+    ) -> pallas_core.KernelConfig:
+  if not maybe_kernel_config:
+    config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, grid=grid)
+  else:
+    config = maybe_kernel_config
+    grid = maybe_kernel_config.grid
+  grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs
+  grid = pallas_core.preprocess_grid(grid)
+  if in_specs is not None and not isinstance(in_specs, (tuple, list)):
+    in_specs = (in_specs,)
+  if out_specs is not None and not isinstance(out_specs, (tuple, list)):
+    out_specs = (out_specs,)
+  if in_specs is None:
+    in_specs = [None] * len(in_avals)
+  if out_specs is None:
+    out_specs = [None] * len(out_avals)
+  return config.replace(grid=grid, in_specs=in_specs, out_specs=out_specs)
+
+def pallas_call(f: Callable, out_shape: Any, *,
                 grid: Optional[Grid] = None,
+                config: Optional[pallas_core.KernelConfig] = None,
                 in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                 out_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                 input_output_aliases: Dict[int, int] = {},
                 interpret: bool = False,
                 name: Optional[str] = None,
-                autotuning_configs: Optional[list[Config]] = None,
+                autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None,
+                debug: bool = False,
                 **compiler_params: Any):
+  if config is not None:
+    if grid is not None or in_specs is not None or out_specs is not None:
+      raise ValueError("Cannot specify both `config` and any of `grid`, "
+                       "`in_specs`, or `out_specs`.")
+    if autotuning_configs is not None:
+      raise ValueError("Cannot specify both `config` and `autotuning_configs`")
+  if autotuning_configs is not None:
+    if grid is not None or in_specs is not None or out_specs is not None:
+      raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, "
+                       "`in_specs`, or `out_specs`.")
   singleton = False
   if not isinstance(out_shape, (tuple, list)):
     out_shape = (out_shape,)
     singleton = True
   if not isinstance(out_shape, tuple):
     out_shape = tuple(out_shape)
-  if in_specs is not None and not isinstance(in_specs, (tuple, list)):
-    in_specs = (in_specs,)
-  if out_specs is not None and not isinstance(out_specs, (tuple, list)):
-    out_specs = (out_specs,)
-
   if not name:
     name = f.__name__ if hasattr(f, "__name__") else "unnamed"
 
@@ -382,29 +408,32 @@ def wrapped(*args):
                           for a in flat_args)
     flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype)
                            for a in flat_out_shapes)
+    canonicalized_configs = []
+    if autotuning_configs is None:
+      canonicalized_configs.append(_canonicalize_kernel_config(config,
+                                                               flat_in_avals,
+                                                               flat_out_avals,
+                                                               in_specs,
+                                                               out_specs,
+                                                               grid))
+    else:
+      canonicalized_configs.extend(map(partial(_canonicalize_kernel_config,
+                                               in_avals=flat_in_avals,
+                                               out_avals=flat_out_avals,
+                                               in_specs=in_specs,
+                                               out_specs=out_specs,
+                                               grid=grid),
+                                       autotuning_configs))
     kernels = []
-    flat_in_specs = in_specs
-    flat_out_specs = out_specs
-    if flat_in_specs is None:
-      flat_in_specs = [None] * len(flat_in_avals)
-    if flat_out_specs is None:
-      flat_out_specs = [None] * len(flat_out_avals)
     all_consts = []
-    if autotuning_configs is None:
+    if len(canonicalized_configs) == 0:
+      raise ValueError("Cannot pass in empty autotuning configs")
+    for canonicalized_config in canonicalized_configs:
       specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
-          Config({}, {}), f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals,
+          canonicalized_config, f, name, flat_in_avals,
          flat_out_avals, jaxpr_in_tree, compiler_params)
       kernels.append(specialized_kernel)
       all_consts.extend(consts)
-    else:
-      if len(autotuning_configs) == 0:
-        raise ValueError("Cannot pass in empty autotuning configs")
-      for config in autotuning_configs:
-        specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
-            config, f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, flat_out_avals,
-            jaxpr_in_tree, compiler_params)
-        kernels.append(specialized_kernel)
-        all_consts.extend(consts)
     if all_consts:
       raise NotImplementedError("Cannot handle consts.")
     del jaxpr_out_tree
diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py
index 636f7b5a..d01feab5 100644
--- a/jax_triton/pallas/triton_ir_lowering.py
+++ b/jax_triton/pallas/triton_ir_lowering.py
@@ -837,9 +837,11 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
   if debug:
     print(kernel.jaxpr)
     print(kernel.grid_spec)
-  compiler_params = kernel.compiler_params
-  num_warps = compiler_params.get("num_warps", 4)
-  num_stages = compiler_params.get("num_stages", 3)
+  compiler_params = dict(kernel.compiler_params)
+  num_warps = compiler_params.pop("num_warps", 4)
+  num_stages = compiler_params.pop("num_stages", 3)
+  if compiler_params:
+    raise ValueError(f"Invalid compiler params: {compiler_params}")
   compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts,
                                      tuple((*in_shapes, *out_shapes)),
                                      kernel.grid_spec, kernel.name, num_warps, num_stages)
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index 3fab4145..47bc4da7 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -857,10 +857,10 @@ def test_basic_autotuning(self):
 
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
-        grid=lambda block_size: 8 // block_size,
         autotuning_configs=[
-          pl.Config(dict(block_size=2), {}),
-          pl.Config(dict(block_size=4), {}),
+          pl.KernelConfig(meta=dict(block_size=block_size),
+                          grid=8 // block_size)
+          for block_size in [1, 2, 4, 8]
         ])
     def add_one(x_ref, o_ref, *, block_size):
       idx = pl.program_id(0) * block_size + jnp.arange(block_size)
@@ -873,18 +873,16 @@ def test_basic_autotuning_with_block_spec(self):
 
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
-        grid=lambda block_size: 8 // block_size,
-        in_specs=[
-          lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)),
-        ],
-        out_specs=[
-          lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)),
-        ],
         autotuning_configs=[
-          pl.Config(dict(block_size=1), {}),
-          pl.Config(dict(block_size=2), {}),
-          pl.Config(dict(block_size=4), {}),
-          pl.Config(dict(block_size=8), {}),
+          pl.KernelConfig(meta=dict(block_size=block_size),
+                          in_specs=[
+                            pl.BlockSpec(lambda i: i, (block_size,))
+                          ],
+                          out_specs=[
+                            pl.BlockSpec(lambda i: i, (block_size,))
+                          ],
+                          grid=8 // block_size)
+          for block_size in [1, 2, 4, 8]
         ], debug=True)
     def add_one(x_ref, o_ref, *, block_size):
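Reviewer note (not part of the patch): the sketch below shows how a kernel would be autotuned with the new KernelConfig API introduced by this change, modeled on the test_basic_autotuning case above. The kernel body after the idx line, the example input, and the final call are illustrative assumptions, not code taken from the patch.

import functools

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

# Each KernelConfig bundles the meta-parameters, the grid, and (optionally)
# block specs for one autotuning candidate, so grid/in_specs/out_specs no
# longer need to be lambdas over the meta-parameters.
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    autotuning_configs=[
        pl.KernelConfig(meta=dict(block_size=block_size),
                        grid=8 // block_size)
        for block_size in [1, 2, 4, 8]
    ])
def add_one(x_ref, o_ref, *, block_size):
  # block_size arrives via KernelConfig.meta; each candidate config produces
  # its own specialization of this kernel.
  idx = pl.program_id(0) * block_size + jnp.arange(block_size)
  o_ref[idx] = x_ref[idx] + 1.  # assumed body; the test above is truncated

x = jnp.arange(8, dtype=jnp.float32)
add_one(x)  # one specialized kernel per candidate config; autotuning picks among them

Because each candidate is a self-contained KernelConfig, specialize_kernel can simply read config.grid, config.in_specs, and config.out_specs instead of calling user-supplied lambdas with **config.meta, which is the point of this patch.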