From 0a292ccdd83ea81efe4757175f6e6983437b5f07 Mon Sep 17 00:00:00 2001 From: The jax_triton Authors Date: Fri, 15 Dec 2023 12:12:35 -0800 Subject: [PATCH] Check that a kernel can be launched before adding it to candidate kernels PiperOrigin-RevId: 591324734 --- jax_triton/triton_lib.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 3617eca1..24432ec2 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -496,15 +496,16 @@ def prune_configs(configs, named_args): if num_ctas != 1: raise ValueError("`num_ctas != 1` is not yet supported.") - kernel_calls.append( - triton_kernel_call_lib.TritonKernelCall( - kernel, - params["grid"][0], - params["grid"][1], - params["grid"][2], - kernel_params, - ) - ) + if kernel.can_launch(): + kernel_calls.append( + triton_kernel_call_lib.TritonKernelCall( + kernel, + params["grid"][0], + params["grid"][1], + params["grid"][2], + kernel_params, + ) + ) if len(kernel_calls) > 1: named_scalar_args = {fn.arg_names[i]: v for i, _, v in scalar_args}