Skip to content

Commit

Permalink
Import openai/triton from GitHub.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597848585
  • Loading branch information
Moerafaat authored and The jax_triton Authors committed Jan 12, 2024
1 parent 3130126 commit b2ae828
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def get_arch_default_num_stages(device_type, capability):

def compile_ttir_to_ptx_inplace(
ttir,
tl_context: tl_ir.Context,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
device: int = 0,
Expand All @@ -198,8 +199,6 @@ def compile_ttir_to_ptx_inplace(
# Triton compilation APIs only accept Triton-specific MLIR wrappers.
# So, here we serialize an ir.Module to a file and then deserialize
# it as a tl_ir.module.
tl_context = tl_ir.context()
tl_context.load_triton()
with tempfile.NamedTemporaryFile(mode="wb") as f:
ttir.operation.write_bytecode(f)
f.flush()
Expand Down Expand Up @@ -229,7 +228,7 @@ def compile_ttir_to_ptx_inplace(
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.translation.get_shared_memory_size(ttgir)
shared_mem_bytes = metadata["shared"]
if cuda_options.debug:
print(llir)
ptx = cuda_backend.make_ptx(
Expand Down Expand Up @@ -322,6 +321,10 @@ def get_or_create_triton_kernel(
)
)

context = _triton.ir.context()
_triton.ir.load_dialects(context)
cuda_backend.load_dialects(context)

module = code_gen.ast_to_ttir(
fn,
specialization=tc.ASTSource(
Expand All @@ -331,12 +334,14 @@ def get_or_create_triton_kernel(
attrs=specialization_attr,
),
options=cuda_options,
context=context,
)

ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
compile_ttir_to_ptx_inplace(
module,
context,
cuda_backend,
cuda_options,
device=device,
Expand Down

0 comments on commit b2ae828

Please sign in to comment.