Skip to content

Commit

Permalink
Check that a kernel can be launched before adding it to candidate ker…
Browse files Browse the repository at this point in the history
…nels

PiperOrigin-RevId: 591324734
  • Loading branch information
The jax_triton Authors committed Jan 9, 2024
1 parent 4a5791d commit 0a292cc
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,15 +496,16 @@ def prune_configs(configs, named_args):
if num_ctas != 1:
raise ValueError("`num_ctas != 1` is not yet supported.")

kernel_calls.append(
triton_kernel_call_lib.TritonKernelCall(
kernel,
params["grid"][0],
params["grid"][1],
params["grid"][2],
kernel_params,
)
)
if kernel.can_launch():
kernel_calls.append(
triton_kernel_call_lib.TritonKernelCall(
kernel,
params["grid"][0],
params["grid"][1],
params["grid"][2],
kernel_params,
)
)

if len(kernel_calls) > 1:
named_scalar_args = {fn.arg_names[i]: v for i, _, v in scalar_args}
Expand Down

0 comments on commit 0a292cc

Please sign in to comment.