Skip to content

Commit

Permalink
Part of internal Triton integrate
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595147751
  • Loading branch information
The jax_triton Authors committed Jan 3, 2024
1 parent 7778c47 commit fd4f7c0
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from triton.compiler import compiler as tc
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton.triton as _triton
import triton._C.libtriton as _triton
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb

Expand Down Expand Up @@ -156,7 +156,13 @@ def aval_size_bytes(aval):


def ptx_get_kernel_name(module) -> str:
return cb.get_kernel_name(module, pattern="// .globl")
# return cb.get_kernel_name(module, pattern="// .globl")
pattern = "// .globl"
assert module
for line in module.split("\n"):
line = line.strip()
if line.startswith(pattern):
return line.split()[-1]


def get_arch_default_num_warps(device_type):
Expand Down Expand Up @@ -216,7 +222,7 @@ def compile_ttir_to_ptx_inplace(
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
shared_mem_bytes = _triton.translation.get_shared_memory_size(ttgir)
if cuda_options.debug:
print(llir)
ptx = cuda_backend.make_ptx(
Expand Down

0 comments on commit fd4f7c0

Please sign in to comment.