From 544469c7e6898fa73d0b6a6820dc6ece2172b0c3 Mon Sep 17 00:00:00 2001
From: Tori Baker
Date: Fri, 3 Jan 2025 04:44:45 -0800
Subject: [PATCH] Integrate Triton up to [632bfc34](https://github.com/openai/triton/commits/632bfc342d3a7d63ce8b21209355139ee070d392)

PiperOrigin-RevId: 711713649
---
 jax_triton/triton_lib.py  | 69 +++++++++------------------------------
 tests/triton_call_test.py | 38 ---------------------
 2 files changed, 15 insertions(+), 92 deletions(-)

diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index ebf563d..23977aa 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -24,7 +24,6 @@
 import os
 import pprint
 import tempfile
-import types
 from typing import Any, Protocol, Union
 import zlib
 
@@ -339,7 +338,7 @@ def get_or_create_triton_kernel(
     enable_fp_fusion,
     metaparams,
     dump: bool,
-) -> tuple[triton_kernel_call_lib.TritonKernel, Any]:
+) -> triton_kernel_call_lib.TritonKernel:
   if num_warps is None:
     num_warps = 4
   if num_stages is None:
@@ -354,21 +353,9 @@ def get_or_create_triton_kernel(
     raise ValueError("num_ctas > 1 unsupported before Hopper.")
 
   signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
-  # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
-  # We assume that all arrays are aligned to 16 bytes, and Triton may use this
-  # assumption, unless array args are include in the `do_not_specialize` list.
-  # We replace array arguments with mock Torch tensors, to allow us to use
-  # `JITFunction._get_config` to get the specialization_attr.
-  mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
-  args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
-  backend = backend_init_func(device, compute_capability)
-  for i, _, v in scalar_args:
-    args_for_specialization_attr[i] = v
-
-  specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
+
   constants = dict(metaparams)
   constants.update({k: None for _, k, v in scalar_args if v is None})
-  constants.update({fn.arg_names[i]: 1 for (i,) in specialization_attr.equal_to_1})
   for constant in constants:
     signature[constant] = "constexpr"
 
@@ -376,7 +363,6 @@ def get_or_create_triton_kernel(
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(specialization_attr.get_fn_attrs()),
       tuple(constants.items()),
       num_warps,
       num_stages,
@@ -396,6 +382,7 @@ def get_or_create_triton_kernel(
         "enable_fp_fusion": enable_fp_fusion,
     }
 
+    backend = backend_init_func(device, compute_capability)
     options = backend.parse_options(opts)
 
     kernel_hash = abs(hash(cache_key))
@@ -410,44 +397,18 @@ def get_or_create_triton_kernel(
     backend.load_dialects(context)
     codegen_fns = backend.get_codegen_implementation()
 
-    module = (
-        code_gen.ast_to_ttir(
-            fn,
-            specialization=tc.ASTSource(
-                fn,
-                constexprs=constants,
-                signature=signature,
-                attrs=specialization_attr,
-            ),
-            options=options,
-            codegen_fns=codegen_fns,
-            context=context,
-            module_map=backend.get_module_map(),
-        )
-        if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
-        # Triton changes ASTSource.ast_to_ttir to include module_map. Handle
-        # backward compatibility here.
-        else code_gen.ast_to_ttir(
-            fn,
-            specialization=tc.ASTSource(
-                fn,
-                constexprs=constants,
-                signature=signature,
-                attrs=specialization_attr,
-            ),
-            options=options,
-            codegen_fns=codegen_fns,
-            context=context,
-        )
+    module = code_gen.ast_to_ttir(
+        fn,
+        tc.ASTSource(fn, constexprs=constants, signature=signature),
+        options=options,
+        codegen_fns=codegen_fns,
+        context=context,
+        module_map=backend.get_module_map(),
     )
     ttir = str(module)
 
     compilation_result = compile_ttir_inplace(
-        module,
-        backend,
-        options,
-        compute_capability,
-        platform
+        module, backend, options, compute_capability, platform
     )
 
     kernel_name = compilation_result.name
@@ -490,7 +451,7 @@ def get_or_create_triton_kernel(
 
     _COMPILED_KERNEL_CACHE[cache_key] = kernel
 
-  return kernel, specialization_attr
+  return kernel
 
 
 def triton_kernel_call_lowering(
@@ -611,7 +572,7 @@ def prune_configs(configs, named_args, **kwargs):
 
   kernel_calls = []
   for params in config_params:
-    kernel, specialization_attr = get_or_create_triton_kernel(
+    kernel = get_or_create_triton_kernel(
        backend_init_func,
        ctx.module_context.platforms[0],
        fn,
@@ -633,10 +594,10 @@ def prune_configs(configs, named_args, **kwargs):
        kernel_params.append(
            triton_kernel_call_lib.create_array_parameter(
                zeroed_params_with_sizes.get(i, 0),
-                16 if (i in specialization_attr.divisibility_16) else 0,
+                0,
            )
        )
-      elif (i,) not in specialization_attr.equal_to_1:
+      else:
        kernel_params.append(
            triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
        )
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index 1ff9a67..a1dfd78 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -531,44 +531,6 @@ def test_autotune_with_input_output_aliasing(self):
     out = add(x, y, kernel=kernel, input_output_aliases={0: 0})
     np.testing.assert_allclose(out, expected)
 
-  def test_specialization(self):
-    do_not_specialize = (
-        0,  # a_ptr
-        2,  # M
-        6,  # stride_ak
-        7,  # stride_bk
-        11,  # c_ptr
-    )
-    kernel = triton.jit(do_not_specialize=do_not_specialize)(matmul_kernel.fn)
-
-    m, n, k = 128, 128, 99
-    x, y = create_random_inputs([m, k], [k, n])
-
-    with mock.patch.object(code_gen, "ast_to_ttir") as mock_compile:
-      try:
-        _ = matmul(
-            x,
-            y,
-            kernel=kernel,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=32,
-            BLOCK_SIZE_K=32,
-            # K_EXACTLY_DIVISIBLE_BY_BLOCK=False,
-        )
-      except TypeError:
-        pass  # Error thrown as the mocked method's return value is invalid.
-
-    mock_compile.assert_called_once()
-    specialization = mock_compile.call_args[1]['specialization']
-
-    # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
-    # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
-    # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-    self.assertEqual(specialization.attrs.divisibility_16, [(1,), (3,), (9,)])
-    # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
-    # specialize" leaving `stride_{bn,cn}`.
-    self.assertEqual(specialization.attrs.equal_to_1, [(8,), (10,)])
-
 
 if __name__ == "__main__":
   os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
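Reviewer note: the net effect of the triton_lib.py hunks is that kernel compilation no
longer builds a specialization/attrs descriptor, and get_or_create_triton_kernel now
returns only the TritonKernel. Below is a condensed sketch of the new compile path, not
part of the patch; it assumes the module-level names of triton_lib.py (backend_init_func,
code_gen, tc, compile_ttir_inplace) and the locals already built by the function (fn,
constants, signature, opts, context, device, compute_capability, platform), and it omits
the kernel cache, MLIR context creation, and TTIR dumping. The helper name
_compile_after_update is made up for illustration.

    # Sketch only: condensed from the new code path in get_or_create_triton_kernel;
    # not runnable outside jax_triton.triton_lib.
    def _compile_after_update(backend_init_func, fn, constants, signature, opts,
                              device, compute_capability, platform, context):
      # The backend is now constructed here, unconditionally, instead of earlier
      # as a by-product of computing specialization attributes.
      backend = backend_init_func(device, compute_capability)
      options = backend.parse_options(opts)
      backend.load_dialects(context)
      codegen_fns = backend.get_codegen_implementation()
      # ASTSource is built without an `attrs` descriptor, and `module_map` is
      # always passed; the inspect-based backward-compatibility branch for older
      # ast_to_ttir signatures is gone.
      module = code_gen.ast_to_ttir(
          fn,
          tc.ASTSource(fn, constexprs=constants, signature=signature),
          options=options,
          codegen_fns=codegen_fns,
          context=context,
          module_map=backend.get_module_map(),
      )
      return compile_ttir_inplace(module, backend, options, compute_capability,
                                  platform)

Because no divisibility_16 / equal_to_1 information is collected any more, array
parameters are created with a divisibility hint of 0 and scalar arguments are always
passed as kernel parameters rather than being specialized away when equal to 1, which is
also why test_specialization is removed from the test suite.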