diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index de4ec620..1cfe3789 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -159,7 +159,7 @@ def compile_ttir_to_ptx_inplace( cuda_options: cb.CUDAOptions, device: int = 0, device_type: str = "cuda", -) -> Tuple[str, str, int, int]: +) -> Tuple[str, str, int, int, Any]: compute_capability = triton_kernel_call_lib.get_compute_capability(device) if cuda_options.debug: print(ttir)