diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 4ef52c42..7cfd986b 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -135,9 +135,6 @@ def get_triton_type(obj: Any) -> str: ) -Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] -GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]] - triton_kernel_call_p = jax.core.Primitive("triton_kernel_call") triton_kernel_call_p.multiple_results = True triton_kernel_call_p.def_impl( @@ -240,7 +237,8 @@ def compile_ttir_to_ptx_inplace( if cuda_options.debug: print(ptx) name = ptx_get_kernel_name(ptx) - return ptx, name, shared_mem_bytes, compute_capability + cluster_dims = metadata["cluster_dims"] + return ptx, name, shared_mem_bytes, compute_capability, cluster_dims _COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache? @@ -312,7 +310,6 @@ def get_or_create_triton_kernel( num_warps=num_warps, num_stages=num_stages, num_ctas=num_ctas, - cluster_dims=(1, 1, 1), enable_warp_specialization=enable_warp_specialization, enable_persistent=enable_persistent, optimize_epilogue=False, @@ -338,7 +335,7 @@ def get_or_create_triton_kernel( ) ttir = str(module) # `module`` is compiled in-place, so copy TTIR here. - ptx, kernel_name, shared_mem_bytes, compute_capability = ( + ptx, kernel_name, shared_mem_bytes, compute_capability, cluster_dims = ( compile_ttir_to_ptx_inplace( module, context, @@ -350,7 +347,13 @@ def get_or_create_triton_kernel( ) kernel = triton_kernel_call_lib.TritonKernel( - kernel_name, num_warps, shared_mem_bytes, ptx, ttir, compute_capability + kernel_name, + num_warps, + shared_mem_bytes, + ptx, + ttir, + compute_capability, + *cluster_dims, ) _COMPILED_KERNEL_CACHE[cache_key] = kernel @@ -580,7 +583,7 @@ def triton_call( custom_call_target_name: str = "triton_kernel_call", num_warps: Optional[int] = None, num_stages: Optional[int] = None, - num_ctas: int = 1, + num_ctas: int = 1, # TODO(giorgioa): Add support for dimensions tuple. enable_fp_fusion: bool = True, enable_warp_specialization: bool = False, enable_persistent: bool = False, @@ -663,6 +666,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: indices, for outputs that should be zeroed before the kernel is launched. num_warps: The number of warps used to execute the Triton kernel. num_stages: The number of stages emitted by the Triton compiler. + num_ctas: The size of thread blocks per cluster to be used on GPUs with + compute capabilities >= 9.0. It must be less or equal to 8. debug: Prints out intermediate IRs if True for debugging purposes. serialized_metadata: Arbitrary metadata that will be added into the serialized kernel call.