
Commit

Internal fixes for Triton integration
PiperOrigin-RevId: 590562619
gflegar authored and The jax_triton Authors committed Dec 13, 2023
1 parent f0beb4c commit 8a7710a
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions jax_triton/triton_lib.py
@@ -190,12 +190,6 @@ def compile_ttir_to_ptx_inplace(
     device_type: str = "cuda",
 ) -> Tuple[str, str, int, int]:
   compute_capability = triton_kernel_call_lib.get_compute_capability(device)
-  if cuda_options.num_warps is None:
-    cuda_options.num_warps = get_arch_default_num_warps(device_type)
-  if cuda_options.num_stages is None:
-    cuda_options.num_stages = get_arch_default_num_stages(
-        device_type, capability=compute_capability
-    )
   if cuda_options.debug:
     print(ttir)
   try:
@@ -216,7 +210,7 @@ def compile_ttir_to_ptx_inplace(
     llir = cuda_backend.make_llir(
         ttgir,
         metadata,
-        cuda_backend.parse_linker_options(dict()),
+        cuda_options,
         compute_capability,
     )
   except RuntimeError as e:
@@ -257,6 +251,13 @@ def get_or_create_triton_kernel(
   device_type = "cuda"
   if num_warps is None:
     num_warps = get_arch_default_num_warps(device_type)
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0
+  # which is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  device = 0
+  arch = triton_kernel_call_lib.get_compute_capability(device)
+  if num_stages is None:
+    num_stages = get_arch_default_num_stages(device_type, arch)
 
   signature = dict(enumerate(arg_dtypes))
   # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
@@ -290,16 +291,11 @@ def get_or_create_triton_kernel(
   kernel = _COMPILED_KERNEL_CACHE.get(cache_key)
 
   if kernel is None:
-    # TODO(sharadmv): handle multiple devices, right now we assume device 0
-    # which is fine when we have multiple of the same GPU but this won't work in
-    # general.
-    device = 0
-    arch = triton_kernel_call_lib.get_compute_capability(device)
 
     target = ("cuda", arch)
     cuda_backend = cb.CUDABackend(target)
 
-    cuda_options = cuda_backend.parse_compiler_options(
+    cuda_options = cuda_backend.parse_options(
        dict(
            num_warps=num_warps,
            num_stages=num_stages,
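
For readers skimming the diff, a minimal sketch of the compile-option flow after this change. It is an illustration, not the library's exact code: it reuses only names that appear in the diff above (get_arch_default_num_warps, get_arch_default_num_stages, triton_kernel_call_lib, cb.CUDABackend, parse_options) and assumes they are in scope exactly as they are inside jax_triton/triton_lib.py.

# Sketch only: mirrors the post-change call order, assuming the names used
# below are in scope as in jax_triton/triton_lib.py.
def build_backend_and_options(num_warps, num_stages):
  device = 0  # single-device assumption, as in the TODO comment above
  device_type = "cuda"
  arch = triton_kernel_call_lib.get_compute_capability(device)

  # num_warps/num_stages defaults are now resolved by the caller, before a
  # backend object exists (previously they were filled in inside
  # compile_ttir_to_ptx_inplace).
  if num_warps is None:
    num_warps = get_arch_default_num_warps(device_type)
  if num_stages is None:
    num_stages = get_arch_default_num_stages(device_type, arch)

  cuda_backend = cb.CUDABackend(("cuda", arch))
  # parse_options replaces the former parse_compiler_options call; the
  # resulting options object is later passed directly to
  # cuda_backend.make_llir instead of a separately parsed linker-options dict.
  cuda_options = cuda_backend.parse_options(
      dict(num_warps=num_warps, num_stages=num_stages)
  )
  return cuda_backend, cuda_options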
