Integrate Triton up to [632bfc34](https://github.com/openai/triton/co…
vwbaker authored and Google-ML-Automation committed Jan 14, 2025
1 parent 859cc39 commit 544469c
Showing 2 changed files with 15 additions and 92 deletions.
69 changes: 15 additions & 54 deletions jax_triton/triton_lib.py
@@ -24,7 +24,6 @@
import os
import pprint
import tempfile
import types
from typing import Any, Protocol, Union
import zlib

@@ -339,7 +338,7 @@ def get_or_create_triton_kernel(
enable_fp_fusion,
metaparams,
dump: bool,
) -> tuple[triton_kernel_call_lib.TritonKernel, Any]:
) -> triton_kernel_call_lib.TritonKernel:
if num_warps is None:
num_warps = 4
if num_stages is None:
@@ -354,29 +353,16 @@ def get_or_create_triton_kernel(
raise ValueError("num_ctas > 1 unsupported before Hopper.")

signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
# We assume that all arrays are aligned to 16 bytes, and Triton may use this
# assumption, unless array args are included in the `do_not_specialize` list.
# We replace array arguments with mock Torch tensors, to allow us to use
# `JITFunction._get_config` to get the specialization_attr.
mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
backend = backend_init_func(device, compute_capability)
for i, _, v in scalar_args:
args_for_specialization_attr[i] = v

specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr) # pylint: disable=protected-access

constants = dict(metaparams)
constants.update({k: None for _, k, v in scalar_args if v is None})
constants.update({fn.arg_names[i]: 1 for (i,) in specialization_attr.equal_to_1})
for constant in constants:
signature[constant] = "constexpr"

# Cache key should contain any parameter that can affect the compiler output.
cache_key = (
fn,
tuple(signature.items()),
tuple(specialization_attr.get_fn_attrs()),
tuple(constants.items()),
num_warps,
num_stages,
@@ -396,6 +382,7 @@ def get_or_create_triton_kernel(
"enable_fp_fusion": enable_fp_fusion,
}

backend = backend_init_func(device, compute_capability)
options = backend.parse_options(opts)

kernel_hash = abs(hash(cache_key))
@@ -410,44 +397,18 @@ def get_or_create_triton_kernel(
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()

module = (
code_gen.ast_to_ttir(
fn,
specialization=tc.ASTSource(
fn,
constexprs=constants,
signature=signature,
attrs=specialization_attr,
),
options=options,
codegen_fns=codegen_fns,
context=context,
module_map=backend.get_module_map(),
)
if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
# Triton changes ASTSource.ast_to_ttir to include module_map. Handle
# backward compatibility here.
else code_gen.ast_to_ttir(
fn,
specialization=tc.ASTSource(
fn,
constexprs=constants,
signature=signature,
attrs=specialization_attr,
),
options=options,
codegen_fns=codegen_fns,
context=context,
)
module = code_gen.ast_to_ttir(
fn,
tc.ASTSource(fn, constexprs=constants, signature=signature),
options=options,
codegen_fns=codegen_fns,
context=context,
module_map=backend.get_module_map(),
)
ttir = str(module)

compilation_result = compile_ttir_inplace(
module,
backend,
options,
compute_capability,
platform
module, backend, options, compute_capability, platform
)

kernel_name = compilation_result.name
@@ -490,7 +451,7 @@ def get_or_create_triton_kernel(

_COMPILED_KERNEL_CACHE[cache_key] = kernel

return kernel, specialization_attr
return kernel


def triton_kernel_call_lowering(
@@ -611,7 +572,7 @@ def prune_configs(configs, named_args, **kwargs):

kernel_calls = []
for params in config_params:
kernel, specialization_attr = get_or_create_triton_kernel(
kernel = get_or_create_triton_kernel(
backend_init_func,
ctx.module_context.platforms[0],
fn,
Expand All @@ -633,10 +594,10 @@ def prune_configs(configs, named_args, **kwargs):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
16 if (i in specialization_attr.divisibility_16) else 0,
0,
)
)
elif (i,) not in specialization_attr.equal_to_1:
else:
kernel_params.append(
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
)
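The substance of the triton_lib.py change is that get_or_create_triton_kernel now builds a plain tc.ASTSource (no attrs descriptor from the backend) and returns only the kernel, so the cache key and the emitted TTIR no longer depend on pointer alignment or equal-to-1 scalar specialization. Below is a minimal sketch of the new lowering step, assuming the module aliases used in triton_lib.py (code_gen for triton.compiler.code_generator, tc for triton.compiler.compiler) and taking the already-prepared backend, options, codegen_fns, and context objects as inputs; the helper name lower_to_ttir is hypothetical.

from triton.compiler import code_generator as code_gen
from triton.compiler import compiler as tc


def lower_to_ttir(fn, signature, constants, backend, options, codegen_fns, context):
  # Build the ASTSource with only the signature and constexprs; there is no
  # per-argument specialization descriptor anymore.
  module = code_gen.ast_to_ttir(
      fn,
      tc.ASTSource(fn, constexprs=constants, signature=signature),
      options=options,
      codegen_fns=codegen_fns,
      context=context,
      module_map=backend.get_module_map(),
  )
  # Serialized TTIR, as computed right before compile_ttir_inplace in the
  # diff above.
  return str(module)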
38 changes: 0 additions & 38 deletions tests/triton_call_test.py
@@ -531,44 +531,6 @@ def test_autotune_with_input_output_aliasing(self):
out = add(x, y, kernel=kernel, input_output_aliases={0: 0})
np.testing.assert_allclose(out, expected)

def test_specialization(self):
do_not_specialize = (
0, # a_ptr
2, # M
6, # stride_ak
7, # stride_bk
11, # c_ptr
)
kernel = triton.jit(do_not_specialize=do_not_specialize)(matmul_kernel.fn)

m, n, k = 128, 128, 99
x, y = create_random_inputs([m, k], [k, n])

with mock.patch.object(code_gen, "ast_to_ttir") as mock_compile:
try:
_ = matmul(
x,
y,
kernel=kernel,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
# K_EXACTLY_DIVISIBLE_BY_BLOCK=False,
)
except TypeError:
pass # Error thrown as the mocked method's return value is invalid.

mock_compile.assert_called_once()
specialization = mock_compile.call_args[1]['specialization']

# Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
# However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
# specialize", leaving `b_ptr`, `N`, and `stride_cm`.
self.assertEqual(specialization.attrs.divisibility_16, [(1,), (3,), (9,)])
# `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
# specialize" leaving `stride_{bn,cn}`.
self.assertEqual(specialization.attrs.equal_to_1, [(8,), (10,)])


if __name__ == "__main__":
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
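The remaining tests in triton_call_test.py exercise the public triton_call entry point, which this commit does not change. For orientation, a minimal sketch of that path; the add kernel, block size, and shapes here are illustrative assumptions, not taken from this diff.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
  # 1D element-wise add, one block per program instance.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(out_ptr + offsets, x + y)


def add(x, y):
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  # block_size is forwarded to the kernel as a constexpr metaparameter;
  # output pointers are appended after the inputs in the kernel argument list.
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // 8,),
      block_size=8,
  )


x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # expect [0., 2., 4., ..., 14.]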
