Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jax_triton] Always use "triton_kernel_call" as custom call target name. #221

Merged
merged 1 commit into from
Aug 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import jax.dlpack
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.lib import xla_client as xc
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -259,7 +258,6 @@ def triton_kernel_call_lowering(
fn,
scalar_args,
name,
call_name,
out_shapes,
grid,
num_warps,
Expand Down Expand Up @@ -404,7 +402,7 @@ def prune_configs(configs, named_args):
for input_idx, output_idx in input_output_aliases
)
kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall(
f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}",
f"{kernel_call_name} ({fn.fn.__name__}) {named_scalar_args}",
[(call, str(config)) for call, config in zip(kernel_calls, configs)],
input_output_aliases_with_sizes,
)
Expand All @@ -420,7 +418,7 @@ def prune_configs(configs, named_args):
else:
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name=call_name,
call_target_name="triton_kernel_call",
out_types=out_types,
operands=array_args,
backend_config=zlib.compress(call_proto),
Expand Down Expand Up @@ -450,7 +448,6 @@ def triton_call(
out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
grid: GridOrLambda,
name: str = "",
call_name: str = "triton_kernel_call", # TODO(cjfj): Remove this.
num_warps: int = 4,
num_stages: int = 2,
input_output_aliases: Optional[Dict[int, int]] = None,
Expand Down Expand Up @@ -545,9 +542,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
raise ValueError(
"`triton_call` is only available when `triton` is installed."
)
xc.register_custom_call_target(
call_name, triton_kernel_call_lib.get_custom_call(), platform="CUDA"
)
out_shape = tree_util.tree_map(
lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape)
flat_args, _ = tree_util.tree_flatten(args)
Expand All @@ -572,7 +566,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
call_name=call_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
num_warps=num_warps,
Expand Down
Loading