From f0beb4cb899bc5c7435179cf516b73988df85587 Mon Sep 17 00:00:00 2001
From: The jax_triton Authors
Date: Wed, 13 Dec 2023 01:23:42 -0800
Subject: [PATCH] Import new version of Triton

PiperOrigin-RevId: 590496513
---
 jax_triton/triton_lib.py  | 173 ++++++++++++++++++++++++--------
 tests/triton_call_test.py |   6 +-
 2 files changed, 114 insertions(+), 65 deletions(-)

diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index fbf784d1..5b05a943 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Module for calling Triton kernels from JAX."""
 
+# b/301982023
 from __future__ import annotations
 
@@ -45,6 +46,9 @@
   import triton.language as tl
   from triton.runtime import autotuner
   import triton._C.libtriton.triton as _triton
+  from triton.common.backend import get_backend
+  import triton.compiler.backends.cuda as cb
+
   CAN_USE_TRITON = True
 except ModuleNotFoundError:
   pass
@@ -52,8 +56,11 @@
   from jax._src.lib import gpu_triton as triton_kernel_call_lib
 except ImportError:
   raise ValueError(
-      "Cannot import jaxlib triton library. You may need a newer version of jaxlib. Try installing a nightly wheel from: https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html or https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html"
-  )
+      "Cannot import jaxlib triton library. You may need a newer"
+      " version of jaxlib. Try installing a nightly wheel from:"
+      " https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html"
+      " or https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html"
+  )
 
 os.environ["TRITON_CACHE_DIR"] = ""
 map, unsafe_map = util.safe_map, map
@@ -96,17 +103,18 @@ def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]:
 def avals_to_layouts(avals):
   return [list(reversed(range(aval.ndim))) for aval in avals]
 
+
 def get_triton_type(obj: Any) -> str:
   if isinstance(obj, (jax.core.ShapedArray, state.AbstractRef)):
     return f"*{_JAX_TO_TRITON_TYPE_MAP[obj.dtype]}"
   if isinstance(obj, tl.constexpr):
     obj = obj.value
   if isinstance(obj, int):
-    if -2**31 <= obj < 2**31:
+    if -(2**31) <= obj < 2**31:
       return "i32"
     elif 2**31 <= obj < 2**32:
       return "u32"
-    elif -2**63 <= obj < 2**63:
+    elif -(2**63) <= obj < 2**63:
       return "i64"
     elif 2**63 <= obj < 2**64:
       return "u64"
@@ -131,7 +139,8 @@ def get_triton_type(obj: Any) -> str:
 triton_kernel_call_p = jax.core.Primitive("triton_kernel_call")
 triton_kernel_call_p.multiple_results = True
 triton_kernel_call_p.def_impl(
-    functools.partial(xla.apply_primitive, triton_kernel_call_p))
+    functools.partial(xla.apply_primitive, triton_kernel_call_p)
+)
 
 
 @triton_kernel_call_p.def_abstract_eval
@@ -147,65 +156,82 @@ def aval_size_bytes(aval):
 
 
 def ptx_get_kernel_name(module) -> str:
-  return tc.get_kernel_name(module, pattern='// .globl')
+  return cb.get_kernel_name(module, pattern="// .globl")
+
+
+def get_arch_default_num_warps(device_type):
+  if device_type in ["cuda", "hip"]:
+    num_warps = 4
+  else:
+    device_backend = get_backend(device_type)
+    assert device_backend
+    arch = device_backend.get_architecture_descriptor()
+    num_warps = arch["num_warps"]
+  return num_warps
+
+
+def get_arch_default_num_stages(device_type, capability):
+  if device_type == "cuda":
+    num_stages = 3 if capability >= 75 else 2
+  else:
+    device_backend = get_backend(device_type)
+    assert device_backend
+    arch = device_backend.get_architecture_descriptor()
+    num_stages = arch["num_stages"]
+
+  return num_stages
 
 
 def compile_ttir_to_ptx_inplace(
     ttir,
+    cuda_backend: cb.CUDABackend,
+    cuda_options: cb.CUDAOptions,
     device: int = 0,
     device_type: str = "cuda",
-    num_warps: Optional[int] = None,
-    num_stages: Optional[int] = None,
-    num_ctas: int = 1,
-    enable_fp_fusion: bool = True,
-    enable_warp_specialization: bool = False,
-    enable_persistent: bool = False,
-    dump: bool = False,
 ) -> Tuple[str, str, int, int]:
   compute_capability = triton_kernel_call_lib.get_compute_capability(device)
-  if num_warps is None:
-    num_warps = tc.get_arch_default_num_warps(device_type)
-  if num_stages is None:
-    num_stages = tc.get_arch_default_num_stages(
+  if cuda_options.num_warps is None:
+    cuda_options.num_warps = get_arch_default_num_warps(device_type)
+  if cuda_options.num_stages is None:
+    cuda_options.num_stages = get_arch_default_num_stages(
         device_type, capability=compute_capability
     )
-  if dump:
+  if cuda_options.debug:
     print(ttir)
   try:
-    target = tc.CudaTargetDescriptor(
-        capability=compute_capability,
-        num_warps=num_warps,
-        enable_fp_fusion=enable_fp_fusion,
-    )
-    ttir = tc.optimize_ttir(ttir, target)
-    ttgir = tc.ttir_to_ttgir(ttir, num_warps, num_ctas=num_ctas, target=target)
-    ttgir = tc.optimize_ttgir(
-        ttgir,
-        num_stages,
-        num_warps,
-        num_ctas=num_ctas,
-        target=target,
-        cluster_info=_triton.ClusterInfo(),
-        enable_warp_specialization=enable_warp_specialization,
-        enable_persistent=enable_persistent,
-        optimize_epilogue=False,
+    metadata = dict()
+    opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
+    ttgir = cuda_backend.make_ttgir(
+        opt_ttir,
+        metadata,
+        cuda_options,
+        compute_capability,
     )
   except RuntimeError as e:
     ttir.dump()
     raise ValueError("TTIR->TTGIR pass failed!") from e
-  if dump:
+  if cuda_options.debug:
     print(ttgir)
-  extern_libs = {}
   try:
-    llir = tc.ttgir_to_llir(ttgir, extern_libs, target, _triton.TMAInfos())
+    llir = cuda_backend.make_llir(
+        ttgir,
+        metadata,
+        cuda_backend.parse_linker_options(dict()),
+        compute_capability,
+    )
   except RuntimeError as e:
     ttgir.dump()
     raise ValueError("TTGIR->LLIR pass failed!") from e
   shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
-  if dump:
+  if cuda_options.debug:
     print(llir)
-  ptx = tc.llir_to_ptx(llir, target)
-  if dump:
+  ptx = cuda_backend.make_ptx(
+      llir,
+      metadata,
+      cuda_options,
+      compute_capability,
+  )
+  if cuda_options.debug:
     print(ptx)
   name = ptx_get_kernel_name(ptx)
   return ptx, name, shared_mem_bytes, compute_capability
@@ -230,29 +256,29 @@ def get_or_create_triton_kernel(
 ) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
   device_type = "cuda"
   if num_warps is None:
-    num_warps = tc.get_arch_default_num_warps(device_type)
+    num_warps = get_arch_default_num_warps(device_type)
 
   signature = dict(enumerate(arg_dtypes))
   # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
   # We assume that all arrays are aligned to 16 bytes, and Triton may use this
   # assumption, unless array args are include in the `do_not_specialize` list.
   # We replace array arguments with mock Torch tensors, to allow us to use
-  # `JITFunction._get_config` to get the specialization.
+  # `JITFunction._get_config` to get the specialization_attr.
   mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
-  args_for_specialization = [mock_torch_tensor] * len(arg_dtypes)
+  args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
   for i, _, v in scalar_args:
-    args_for_specialization[i] = v
-  specialization = fn._get_config(*args_for_specialization)  # pylint: disable=protected-access
+    args_for_specialization_attr[i] = v
+  specialization_attr = fn._get_config(*args_for_specialization_attr)  # pylint: disable=protected-access
   constants = {fn.arg_names.index(k): v for k, v in metaparams.items()}
   constants.update({i: None for i, _, v in scalar_args if v is None})
-  constants.update({i: 1 for i in specialization.equal_to_1})
+  constants.update({i: 1 for i in specialization_attr.equal_to_1})
 
   # Cache key should contain any parameter that can affect the compiler output.
   cache_key = (
       fn,
       tuple(signature.items()),
-      specialization,
+      tuple(vars(specialization_attr).values()),
       tuple(constants.items()),
       num_warps,
       num_stages,
@@ -269,22 +295,43 @@
     # general.
     device = 0
    arch = triton_kernel_call_lib.get_compute_capability(device)
+
+    target = ("cuda", arch)
+    cuda_backend = cb.CUDABackend(target)
+
+    cuda_options = cuda_backend.parse_compiler_options(
+        dict(
+            num_warps=num_warps,
+            num_stages=num_stages,
+            num_ctas=num_ctas,
+            cluster_dims=(1, 1, 1),
+            enable_warp_specialization=enable_warp_specialization,
+            enable_persistent=enable_persistent,
+            optimize_epilogue=False,
+            debug=dump,
+            enable_fp_fusion=enable_fp_fusion,
+        )
+    )
+
     module = code_gen.ast_to_ttir(
-        fn, signature, specialization, constants, debug=dump, target=arch
+        fn,
+        specialization=tc.ASTSource(
+            fn,
+            constants=constants,
+            signature=signature,
+            attrs=specialization_attr,
+        ),
+        options=cuda_options,
     )
+
     ttir = str(module)  # `module`` is compiled in-place, so copy TTIR here.
     ptx, kernel_name, shared_mem_bytes, compute_capability = (
         compile_ttir_to_ptx_inplace(
             module,
+            cuda_backend,
+            cuda_options,
             device=device,
             device_type=device_type,
-            num_warps=num_warps,
-            num_stages=num_stages,
-            num_ctas=num_ctas,
-            enable_fp_fusion=enable_fp_fusion,
-            enable_warp_specialization=enable_warp_specialization,
-            enable_persistent=enable_persistent,
-            dump=dump,
         )
     )
 
@@ -294,7 +341,7 @@ def get_or_create_triton_kernel(
 
     _COMPILED_KERNEL_CACHE[cache_key] = kernel
 
-  return kernel, specialization
+  return kernel, specialization_attr
 
 
 def triton_kernel_call_lowering(
@@ -320,7 +367,8 @@
 ):
   if jaxlib.version.__version_info__ < (0, 3, 22) and input_output_aliases:
     raise NotImplementedError(
-        "`input_output_aliases` only supported on `jaxlib>=0.3.22")
+        "`input_output_aliases` only supported on `jaxlib>=0.3.22"
+    )
 
   kernel_call_name = name
   args = list(ctx.avals_in)
@@ -419,7 +467,7 @@ def prune_configs(configs, named_args):
 
   kernel_calls = []
   for params in config_params:
-    kernel, specialization = get_or_create_triton_kernel(
+    kernel, specialization_attr = get_or_create_triton_kernel(
         fn,
         arg_dtypes,
         scalar_args,
@@ -440,10 +488,10 @@ def prune_configs(configs, named_args):
         kernel_params.append(
             triton_kernel_call_lib.create_array_parameter(
                 zeroed_params_with_sizes.get(i, 0),
-                16 if (i in specialization.divisible_by_16) else 0,
+                16 if (i in specialization_attr.divisible_by_16) else 0,
             )
         )
-      elif i not in specialization.equal_to_1:
+      elif i not in specialization_attr.equal_to_1:
         kernel_params.append(
             triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
         )
@@ -615,7 +663,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
         "`triton_call` is only available when `triton` is installed."
     )
   out_shape = tree_util.tree_map(
-      lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape)
+      lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape
+  )
   flat_args, _ = tree_util.tree_flatten(args)
   # TODO(sharadmv): check in_tree is flat (no Pytrees allowed in triton_call)
   flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index ebdf41e9..b9df7a81 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -539,15 +539,15 @@ def test_specialization(self):
       pass  # Error thrown as the mocked method's return value is invalid.
 
     mock_compile.assert_called_once()
-    specialization = mock_compile.call_args.args[2]
+    specialization = mock_compile.call_args[1]['specialization']
 
     # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
     # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
     # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-    self.assertEqual(specialization.divisible_by_16, (1, 3, 9))
+    self.assertEqual(specialization.attrs.divisible_by_16, (1, 3, 9))
 
     # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
     # specialize" leaving `stride_{bn,cn}`.
-    self.assertEqual(specialization.equal_to_1, (8, 10))
+    self.assertEqual(specialization.attrs.equal_to_1, (8, 10))
 
 
 if __name__ == "__main__":
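
The public `jax_triton.triton_call` API is unchanged by this import; only the internal lowering now goes through Triton's `CUDABackend` pipeline (make_ttir -> make_ttgir -> make_llir -> make_ptx). Below is a minimal sketch that would exercise the updated path end to end, assuming a CUDA-capable jaxlib and GPU; the `add_kernel` and the sizes are illustrative (along the lines of the example in `triton_lib.py`'s docstring) and are not part of this change.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # Illustrative elementwise-add kernel (not part of this patch).
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(output_ptr + offsets, x + y)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  # `block_size` is forwarded as a metaparameter to the kernel's constexpr
  # argument; `grid` is the number of program instances to launch.
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size,
  )


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.arange(8, 16, dtype=jnp.float32)
print(jax.jit(add)(x, y))

Under `jax.jit`, the call lowers through `triton_kernel_call_p`, and each distinct `cache_key` compiles the kernel once via `get_or_create_triton_kernel`, which is where the new `CUDABackend` / `ASTSource` code above takes effect.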