[jax_triton] Add parameter allowing user to compile for specific compute capability.

PiperOrigin-RevId: 612431401
chr1sj0nes authored and The jax_triton Authors committed Mar 4, 2024
1 parent 4a50e1d commit 9f04807
Showing 1 changed file with 20 additions and 19 deletions.
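
For context, a minimal sketch of how a caller could use the new parameter after this change. The `add_kernel`, the block size, and the value `compute_capability=80` are illustrative assumptions, not part of this commit; only the `compute_capability` keyword itself is introduced here.

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # Illustrative elementwise-add kernel; assumes the input size is a
  # multiple of block_size, so no bounds mask is needed.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(output_ptr + offsets, x + y)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  block_size = 128
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
      grid=(x.size // block_size,),
      block_size=block_size,
      # New in this commit: compile for an explicit target instead of the
      # compute capability reported by device 0. 80 (Ampere) is an example.
      compute_capability=80,
  )


x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x)[:4])  # [0. 2. 4. 6.]
```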
jax_triton/triton_lib.py (20 additions, 19 deletions)
@@ -157,10 +157,8 @@ def compile_ttir_to_ptx_inplace(
     ttir,
     cuda_backend: cb.CUDABackend,
     cuda_options: cb.CUDAOptions,
-    device: int = 0,
-    device_type: str = "cuda",
-) -> Tuple[str, str, int, int, Any]:
-  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
+    compute_capability,
+) -> Tuple[str, str, int, Any]:
   if cuda_options.debug:
     print(ttir)
   if isinstance(ttir, ir.Module):
@@ -213,7 +211,7 @@ def compile_ttir_to_ptx_inplace(
     print(ptx)
   name = metadata["name"]
   cluster_dims = metadata["cluster_dims"]
-  return ptx, name, shared_mem_bytes, compute_capability, cluster_dims
+  return ptx, name, shared_mem_bytes, cluster_dims
 
 
 _COMPILED_KERNEL_CACHE = {}  # TODO(cjfj): Convert to LRU cache?
@@ -227,22 +225,23 @@ def get_or_create_triton_kernel(
     num_warps,
     num_stages,
     num_ctas,
+    compute_capability,
     enable_fp_fusion,
     metaparams,
     dump: bool,
 ) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
-  device_type = "cuda"
   if num_warps is None:
     num_warps = 4
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  arch = triton_kernel_call_lib.get_compute_capability(device)
-  if num_ctas > 1 and arch < 90:
-    raise ValueError("num_ctas > 1 unsupported before Hopper.")
   if num_stages is None:
     num_stages = 3
+  if compute_capability is None:
+    # TODO(sharadmv): handle multiple devices, right now we assume device 0
+    # which is fine when we have multiple of the same GPU but this won't work in
+    # general.
+    device = 0
+    compute_capability = triton_kernel_call_lib.get_compute_capability(device)
+  if num_ctas > 1 and compute_capability < 90:
+    raise ValueError("num_ctas > 1 unsupported before Hopper.")
 
   signature = dict(enumerate(arg_dtypes))
   # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
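
The hunk above keeps the old behaviour as a fallback: when `compute_capability` is `None`, the kernel is still compiled for whatever device 0 reports. A hedged sketch of how a caller might derive an explicit value to pass instead, assuming a recent JAX in which CUDA `Device` objects expose a `compute_capability` string such as "8.0" (that attribute, the helper name, and the "gpu" backend filter are assumptions, not part of this commit):

```python
import jax


def local_compute_capability(device_index: int = 0) -> int:
  # Assumed: jax.Device.compute_capability returns a string like "8.0" on
  # CUDA devices. Convert it to the integer form (e.g. 80) that
  # get_or_create_triton_kernel compares against 90 for the Hopper check.
  major, minor = jax.devices("gpu")[device_index].compute_capability.split(".")
  return int(major) * 10 + int(minor)


# e.g. jt.triton_call(..., compute_capability=local_compute_capability(1))
```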
@@ -269,15 +268,14 @@ def get_or_create_triton_kernel(
       num_warps,
       num_stages,
       num_ctas,
+      compute_capability,
       enable_fp_fusion,
   )
   kernel = _COMPILED_KERNEL_CACHE.get(cache_key)
 
   if kernel is None:
-
-    target = ("cuda", arch)
+    target = ("cuda", compute_capability)
     cuda_backend = cb.CUDABackend(target)
-
     cuda_options = cuda_backend.parse_options(
         dict(
             num_warps=num_warps,
@@ -306,13 +304,12 @@ def get_or_create_triton_kernel(
     )
 
     ttir = str(module)  # `module`` is compiled in-place, so copy TTIR here.
-    ptx, kernel_name, shared_mem_bytes, compute_capability, cluster_dims = (
+    ptx, kernel_name, shared_mem_bytes, cluster_dims = (
         compile_ttir_to_ptx_inplace(
             module,
             cuda_backend,
             cuda_options,
-            device=device,
-            device_type=device_type,
+            compute_capability,
         )
     )
 
@@ -343,6 +340,7 @@ def triton_kernel_call_lowering(
     num_warps,
     num_stages,
     num_ctas,
+    compute_capability,
     enable_fp_fusion,
     input_output_aliases,
     zeroed_outputs,
@@ -455,6 +453,7 @@ def prune_configs(configs, named_args):
         num_warps=params["num_warps"],
         num_stages=params["num_stages"],
         num_ctas=params["num_ctas"],
+        compute_capability=compute_capability,
         enable_fp_fusion=enable_fp_fusion,
         metaparams=dict(params["metaparams"]),
         dump=debug,
@@ -542,6 +541,7 @@ def triton_call(
     num_warps: Optional[int] = None,
     num_stages: Optional[int] = None,
     num_ctas: int = 1,  # TODO(giorgioa): Add support for dimensions tuple.
+    compute_capability: Optional[int] = None,
     enable_fp_fusion: bool = True,
     input_output_aliases: Optional[Dict[int, int]] = None,
     zeroed_outputs: Union[
@@ -668,6 +668,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
       num_warps=num_warps,
       num_stages=num_stages,
       num_ctas=num_ctas,
+      compute_capability=compute_capability,
       enable_fp_fusion=enable_fp_fusion,
       input_output_aliases=tuple(input_output_aliases.items()),
       zeroed_outputs=zeroed_outputs,