Skip to content

Commit

Permalink
Fixes on top of Triton import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596935355
  • Loading branch information
Moerafaat authored and The jax_triton Authors committed Jan 10, 2024
1 parent 4a5791d commit e942d4c
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def compile_ttir_to_ptx_inplace(
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.translation.get_shared_memory_size(ttgir)
shared_mem_bytes = metadata["shared"]
if cuda_options.debug:
print(llir)
ptx = cuda_backend.make_ptx(
Expand Down Expand Up @@ -309,6 +309,10 @@ def get_or_create_triton_kernel(
)
)

context = _triton.ir.context()
_triton.ir.load_dialects(context)
cuda_backend.load_dialects(context)

module = code_gen.ast_to_ttir(
fn,
specialization=tc.ASTSource(
Expand All @@ -318,6 +322,7 @@ def get_or_create_triton_kernel(
attrs=specialization_attr,
),
options=cuda_options,
context=context,
)

ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.
Expand Down

0 comments on commit e942d4c

Please sign in to comment.