From 8a7710a5624f1ef9f63b66713edc6c738c3448c7 Mon Sep 17 00:00:00 2001
From: Goran Flegar
Date: Wed, 13 Dec 2023 05:26:13 -0800
Subject: [PATCH] Internal fixes for Triton integration

PiperOrigin-RevId: 590562619
---
 jax_triton/triton_lib.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index 5b05a943..565a7c2f 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -190,12 +190,6 @@ def compile_ttir_to_ptx_inplace(
     device_type: str = "cuda",
 ) -> Tuple[str, str, int, int]:
   compute_capability = triton_kernel_call_lib.get_compute_capability(device)
-  if cuda_options.num_warps is None:
-    cuda_options.num_warps = get_arch_default_num_warps(device_type)
-  if cuda_options.num_stages is None:
-    cuda_options.num_stages = get_arch_default_num_stages(
-        device_type, capability=compute_capability
-    )
   if cuda_options.debug:
     print(ttir)
   try:
@@ -216,7 +210,7 @@ def compile_ttir_to_ptx_inplace(
     llir = cuda_backend.make_llir(
         ttgir,
         metadata,
-        cuda_backend.parse_linker_options(dict()),
+        cuda_options,
         compute_capability,
     )
   except RuntimeError as e:
@@ -257,6 +251,13 @@ def get_or_create_triton_kernel(
     device_type = "cuda"
   if num_warps is None:
     num_warps = get_arch_default_num_warps(device_type)
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0
+  # which is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  device = 0
+  arch = triton_kernel_call_lib.get_compute_capability(device)
+  if num_stages is None:
+    num_stages = get_arch_default_num_stages(device_type, arch)
   signature = dict(enumerate(arg_dtypes))
   # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
@@ -290,16 +291,11 @@ def get_or_create_triton_kernel(
   kernel = _COMPILED_KERNEL_CACHE.get(cache_key)
   if kernel is None:
-    # TODO(sharadmv): handle multiple devices, right now we assume device 0
-    # which is fine when we have multiple of the same GPU but this won't work in
-    # general.
-    device = 0
-    arch = triton_kernel_call_lib.get_compute_capability(device)
     target = ("cuda", arch)
     cuda_backend = cb.CUDABackend(target)
-    cuda_options = cuda_backend.parse_compiler_options(
+    cuda_options = cuda_backend.parse_options(
         dict(
             num_warps=num_warps,
             num_stages=num_stages,
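
Note (not part of the patch): after this change, compile_ttir_to_ptx_inplace no longer fills in default num_warps/num_stages; get_or_create_triton_kernel resolves them up front from the device's compute capability and passes the fully populated cuda_options straight to make_llir. The minimal, self-contained Python sketch below illustrates that defaulting flow under stated assumptions; the helper bodies and concrete numbers are invented stand-ins for illustration only (only the names get_arch_default_num_warps, get_arch_default_num_stages, and get_compute_capability come from the patch), not jax_triton's actual defaults.

# Sketch only: simplified stand-ins for jax_triton / triton_kernel_call_lib
# internals; bodies and values are assumptions for illustration.
def get_compute_capability(device: int) -> int:
  """Stand-in for triton_kernel_call_lib.get_compute_capability."""
  return 80  # pretend device 0 is an sm_80 GPU


def get_arch_default_num_warps(device_type: str) -> int:
  """Stand-in default warp count; the real value comes from Triton."""
  return 4


def get_arch_default_num_stages(device_type: str, capability: int) -> int:
  """Stand-in default pipeline stages; the real value depends on the arch."""
  return 3 if capability >= 75 else 2


def resolve_kernel_defaults(num_warps=None, num_stages=None, device_type="cuda"):
  # Mirrors the hoisted logic in get_or_create_triton_kernel: assume device 0
  # (see the TODO in the patch), query its compute capability once, and fill
  # in any unset options before the backend's parse_options sees them.
  device = 0
  arch = get_compute_capability(device)
  if num_warps is None:
    num_warps = get_arch_default_num_warps(device_type)
  if num_stages is None:
    num_stages = get_arch_default_num_stages(device_type, arch)
  return num_warps, num_stages, arch


if __name__ == "__main__":
  # Under the assumed stand-ins above this prints (4, 3, 80).
  print(resolve_kernel_defaults())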