From 34146e8b39f197950f3bad42cd78b516e2ce3a9f Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 31 Aug 2023 04:50:32 -0700 Subject: [PATCH] Fix type annotation for `triton_call` `args`. PiperOrigin-RevId: 561619415 --- jax_triton/triton_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index ceeb7e97..d4f2cee0 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -443,7 +443,7 @@ def dtype(self) -> np.dtype: def triton_call( - *args: Union[jax.Array, bool, int, float], + *args: Union[jax.Array, bool, int, float, np.float32], kernel: triton.JITFunction, out_shape: Union[ShapeDtype, Sequence[ShapeDtype]], grid: GridOrLambda,