diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 565a7c2..3617eca 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -45,7 +45,7 @@ from triton.compiler import compiler as tc import triton.language as tl from triton.runtime import autotuner - import triton._C.libtriton.triton as _triton + import triton._C.libtriton as _triton from triton.common.backend import get_backend import triton.compiler.backends.cuda as cb @@ -216,7 +216,7 @@ def compile_ttir_to_ptx_inplace( except RuntimeError as e: ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e - shared_mem_bytes = _triton.get_shared_memory_size(ttgir) + shared_mem_bytes = _triton.translation.get_shared_memory_size(ttgir) if cuda_options.debug: print(llir) ptx = cuda_backend.make_ptx(