From 5ead0b081d5b2a77f885799b69d242b6c6bdaa1d Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Sun, 22 Dec 2024 02:22:29 -0800 Subject: [PATCH] Add Heuristics and Autotuner to type annotation --- jax_triton/triton_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index a98351f..3aaded6 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -703,7 +703,7 @@ def dtype(self) -> np.dtype: def triton_call( *args: jax.Array | bool | int | float | np.float32, - kernel: triton.JITFunction, + kernel: triton.JITFunction | triton.runtime.Heuristics | triton.runtime.Autotuner, out_shape: ShapeDtype | Sequence[ShapeDtype], grid: GridOrLambda, name: str = "",