Skip to content

Commit 0581131

Browse files
Merge pull request #323 from stephen-huan:fix-annotation
PiperOrigin-RevId: 774739180
2 parents d0cfcec + 5ead0b0 commit 0581131

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax_triton/triton_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def dtype(self) -> np.dtype:
729729

730730
def triton_call(
731731
*args: jax.Array | bool | int | float | np.float32,
732-
kernel: triton.JITFunction,
732+
kernel: triton.JITFunction | triton.runtime.Heuristics | triton.runtime.Autotuner,
733733
out_shape: ShapeDtype | Sequence[ShapeDtype],
734734
grid: GridOrLambda,
735735
name: str = "",

0 commit comments

Comments
 (0)