Pallas autotuning #108

Open: wants to merge 6 commits into main
1 change: 1 addition & 0 deletions jax_triton/pallas/__init__.py
@@ -14,6 +14,7 @@

"""Module for pallas, a jaxpr "dialect" for Triton."""
from jax_triton.pallas.core import BlockSpec
from jax_triton.pallas.core import KernelConfig
from jax_triton.pallas.pallas_call import pallas_call
from jax_triton.pallas.pallas_call import pallas_call_p
from jax_triton.pallas.primitives import atomic_add
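
Note: this re-export makes KernelConfig reachable from the public pallas namespace (it is used as pl.KernelConfig in ops/matmul.py below). A minimal construction sketch with illustrative field values, not taken from the diff; in_specs and out_specs are left at their None defaults:

from jax_triton import pallas as pl

config = pl.KernelConfig(
    name="bm=128_bn=128_bk=32",    # label for this specialization
    meta=dict(bk=32),              # extra keyword arguments forwarded to the kernel
    grid=(8, 8),                   # launch grid used for this config
    compiler_params=dict(triton=dict(num_stages=4, num_warps=4)),
)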
73 changes: 72 additions & 1 deletion jax_triton/pallas/core.py
@@ -15,10 +15,29 @@
"""Module for pallas-core functionality."""
import contextlib
import dataclasses
import functools
from functools import partial

from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.lax.control_flow import for_loop
from jax.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache, safe_map, safe_zip
from jax._src.state.types import AbstractRef

from jax_triton.pallas import tracing_utils

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Grid = tuple[int, ...]
GridOrLambda = Union[Callable[..., Grid], Grid]

@dataclasses.dataclass
class GridEnv:
@@ -73,3 +92,55 @@ class GridSpec:
mapped_dims: Tuple[int, ...]

replace = dataclasses.replace

Platform = str

@dataclasses.dataclass
class KernelConfig:
name: Optional[str] = None
in_specs: Optional[Sequence[Optional[BlockSpec]]] = None
out_specs: Optional[Sequence[Optional[BlockSpec]]] = None
grid: Optional[Union[Grid, int]] = None
meta: dict[str, Any] = dataclasses.field(default_factory=dict)
compiler_params: dict[Platform, dict[str, Any]] = dataclasses.field(default_factory=dict)

def replace(self, *args, **kwargs):
return dataclasses.replace(self, *args, **kwargs)

def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid:
if grid is None:
return ()
if isinstance(grid, int):
return (grid,)
return grid

def extract_function_name(f: Callable, name: Optional[str]) -> str:
if name is None:
name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
return name

def convert_block_spec_to_block_mapping(
grid: Grid, block_spec: Optional[BlockSpec]) -> Optional[BlockMapping]:
if block_spec is None:
return None
in_avals = [jax_core.ShapedArray((), jnp.int32) for _ in grid]
block_shape = tuple(
mapped if s is None else s for s in block_spec.block_shape)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(block_spec.compute_index), in_avals)
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))

def compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
arg_shape: tuple[int, ...]
) -> tuple[int, ...]:
if block_spec is None:
return arg_shape
return tuple(s for s in block_spec.block_shape if s is not None)

@dataclasses.dataclass
class SpecializedKernel:
name: str
jaxpr: jax_core.Jaxpr
num_consts: int
grid_spec: GridSpec
compiler_params: dict[str, Any]
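
Note: a small sketch (not part of the diff) illustrating the intended behavior of the helpers above. It assumes jax_triton.pallas.core is importable directly and that BlockSpec takes (index_map, block_shape), as in ops/matmul.py below:

from jax_triton import pallas as pl
from jax_triton.pallas import core as pallas_core

# preprocess_grid normalizes the user-facing grid argument to a tuple.
assert pallas_core.preprocess_grid(None) == ()
assert pallas_core.preprocess_grid(8) == (8,)
assert pallas_core.preprocess_grid((4, 4)) == (4, 4)

# compute_shape_from_block_spec drops mapped (None) block dimensions.
spec = pl.BlockSpec(lambda i, j: (i, j), (None, 128))
assert pallas_core.compute_shape_from_block_spec(spec, (1024, 1024)) == (128,)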
80 changes: 80 additions & 0 deletions jax_triton/pallas/ops/matmul.py
@@ -0,0 +1,80 @@
# Copyright 2023 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing fused matmuls."""
import functools

from typing import Any, Optional

import jax
import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop

import jax_triton as jt
from jax_triton import pallas as pl

def _compute_bound_configs():
yield from [
dict(bm=128, bn=256, bk=32, compiler_params=dict(triton=dict(num_stages=3, num_warps=8))),
dict(bm=256, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=3, num_warps=8))),
dict(bm=256, bn=64 , bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))),
dict(bm=64, bn=256, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))),
dict(bm=128, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))),
dict(bm=128, bn=64, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))),
dict(bm=64, bn=128, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))),
dict(bm=128, bn=32, bk=32, compiler_params=dict(triton=dict(num_stages=4, num_warps=4))),
dict(bm=64, bn=32, bk=32, compiler_params=dict(triton=dict(num_stages=5, num_warps=2))),
]

@functools.partial(jax.jit, static_argnames=["interpret", "debug"])
def matmul(x, y, *, interpret=False, debug=False):
# TODO(sharadmv): make this implementation better
# 1. reordered programs for better L2
# 2. split K
# 3. masking
m, n, k = x.shape[0], y.shape[1], x.shape[1]

@functools.partial(
pl.pallas_call, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
interpret=interpret,
debug=debug,
autotuning_configs=[
pl.KernelConfig(
name=f"bm={config['bm']}_bn={config['bn']}_bk={config['bk']}",
meta=dict(bk=config["bk"]),
in_specs=[
pl.BlockSpec(lambda i, _: (i, 0), (config["bm"], x.shape[1])),
pl.BlockSpec(lambda _, j: (0, j), (y.shape[0], config["bn"]))
],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (config["bm"],
config["bn"])),
grid=(jt.cdiv(m, config["bm"]), jt.cdiv(n, config["bn"])),
compiler_params=config["compiler_params"],
)
for config in _compute_bound_configs()
]
)
def matmul_kernel(x_ref, y_ref, o_ref, *, bk: int):
acc = jnp.zeros(o_ref.shape, dtype=jnp.float32)
def body(i, acc_ref):
x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk)))
y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None)))
acc_ref[:, :] += jnp.dot(x_block, y_block)
acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
o_ref[:, :] = acc
return matmul_kernel(x, y)

@jax.jit
def matmul_reference(x, y):
return jnp.dot(x, y)
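
Note: a hedged end-to-end sketch of calling the autotuned matmul against the reference, not part of the diff. It assumes the module is importable as jax_triton.pallas.ops.matmul and picks shapes that divide every candidate block size evenly, since the kernel does not yet mask partial tiles (see the TODO above):

import jax
import jax.numpy as jnp
import numpy as np

from jax_triton.pallas.ops.matmul import matmul, matmul_reference  # assumed import path

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (1024, 2048), dtype=jnp.float32)
y = jax.random.normal(k2, (2048, 512), dtype=jnp.float32)

out = matmul(x, y)                 # autotunes over _compute_bound_configs()
expected = matmul_reference(x, y)

# Loose tolerances: accumulation order differs between the blocked kernel
# and jnp.dot, and GPUs may use TF32 for the reference.
np.testing.assert_allclose(out, expected, atol=1e-2, rtol=1e-2)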