diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index c6f5bdf0..f057da1c 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -319,6 +319,7 @@ def get_or_create_triton_kernel( _triton.ir.load_dialects(context) cuda_backend.load_dialects(context) codegen_fns = cuda_backend.get_codegen_implementation() + module_map = cuda_backend.get_module_map() module = code_gen.ast_to_ttir( fn, @@ -331,6 +332,7 @@ def get_or_create_triton_kernel( options=cuda_options, codegen_fns=codegen_fns, context=context, + module_map=module_map ) ttir = str(module)