From e942d4ccce37b95e6f245dc001656531a408f318 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 9 Jan 2024 08:03:44 -0800 Subject: [PATCH] Fixes on top of Triton import PiperOrigin-RevId: 596935355 --- jax_triton/triton_lib.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 3617eca1..dfbbeeba 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -216,7 +216,7 @@ def compile_ttir_to_ptx_inplace( except RuntimeError as e: ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e - shared_mem_bytes = _triton.translation.get_shared_memory_size(ttgir) + shared_mem_bytes = metadata["shared"] if cuda_options.debug: print(llir) ptx = cuda_backend.make_ptx( @@ -309,6 +309,10 @@ def get_or_create_triton_kernel( ) ) + context = _triton.ir.context() + _triton.ir.load_dialects(context) + cuda_backend.load_dialects(context) + module = code_gen.ast_to_ttir( fn, specialization=tc.ASTSource( @@ -318,6 +322,7 @@ def get_or_create_triton_kernel( attrs=specialization_attr, ), options=cuda_options, + context=context, ) ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.