[jax_triton] Always use "triton_kernel_call" as custom call target name.
Now that there is a `name` field in the serialized proto, there is an alternative way to differentiate kernel calls.

PiperOrigin-RevId: 556806408
chr1sj0nes authored and The jax_triton Authors committed Aug 14, 2023
1 parent a25b1ba · commit 20d5f22
Showing 1 changed file with 13 additions and 10 deletions.
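
For orientation, here is a minimal sketch of calling `triton_call` with the new `name` keyword in place of the removed `call_name`. The kernel follows the add-kernel example from the jax_triton README; the label "add_vectors" and the block size are arbitrary illustrations, not part of this commit.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # Each program instance adds one block_size-sized chunk of the inputs.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(output_ptr + offsets, x + y)

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // 8,),
      name="add_vectors",  # was `call_name=...` before this change
      block_size=8)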
jax_triton/triton_lib.py: 23 changes (13 additions & 10 deletions)
@@ -258,7 +258,7 @@ def triton_kernel_call_lowering(
     *array_args,
     fn,
     scalar_args,
-    call_name,
+    name,
     out_shapes,
     grid,
     num_warps,
@@ -273,6 +273,7 @@ def triton_kernel_call_lowering(
     raise NotImplementedError(
         "`input_output_aliases` only supported on `jaxlib>=0.3.22")
 
+  kernel_call_name = name
   args = list(ctx.avals_in)
   arg_dtypes = list(map(get_triton_type, ctx.avals_in))
   for idx, dtype, v in scalar_args:
@@ -402,7 +403,7 @@ def prune_configs(configs, named_args):
         for input_idx, output_idx in input_output_aliases
     )
     kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall(
-        f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}",
+        f"{kernel_call_name} ({fn.fn.__name__}) {named_scalar_args}",
         [(call, str(config)) for call, config in zip(kernel_calls, configs)],
         input_output_aliases_with_sizes,
     )
@@ -413,19 +414,24 @@ def prune_configs(configs, named_args):
       ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
       for shape in out_shapes
   ]
 
+  call_proto = kernel_call.to_proto(kernel_call_name, serialized_metadata)
   return jaxlib.hlo_helpers.custom_call(
-      call_target_name=call_name,
+      call_target_name="triton_kernel_call",
       out_types=out_types,
       operands=array_args,
-      backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
+      backend_config=zlib.compress(call_proto),
       operand_layouts=avals_to_layouts(ctx.avals_in),
       result_layouts=avals_to_layouts(ctx.avals_out),
       operand_output_aliases=dict(input_output_aliases),
   )
 
 
 mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)
+xc.register_custom_call_target(
+    "triton_kernel_call",
+    triton_kernel_call_lib.get_custom_call(),
+    platform="CUDA",
+)
 
 
 class ShapeDtype(Protocol):
@@ -444,7 +450,7 @@ def triton_call(
     kernel: triton.JITFunction,
     out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
     grid: GridOrLambda,
-    call_name: str = "triton_kernel_call",
+    name: str = "",
     num_warps: int = 4,
     num_stages: int = 2,
     input_output_aliases: Optional[Dict[int, int]] = None,
@@ -539,9 +545,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
     raise ValueError(
         "`triton_call` is only available when `triton` is installed."
     )
-  xc.register_custom_call_target(
-      call_name, triton_kernel_call_lib.get_custom_call(), platform="CUDA"
-  )
   out_shape = tree_util.tree_map(
       lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape)
   flat_args, _ = tree_util.tree_flatten(args)
@@ -565,7 +568,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
       *array_args,
       fn=kernel,
       scalar_args=tuple(scalar_args),
-      call_name=call_name,
+      name=name,
       out_shapes=tuple(flat_out_shapes),
       grid=grid,
       num_warps=num_warps,
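
A consequence of this change: every kernel now lowers to the same `triton_kernel_call` custom call target, and the per-kernel `name` travels inside the zlib-compressed proto passed as `backend_config`. A rough way to observe the fixed target name, assuming a CUDA backend and the hypothetical `add` wrapper sketched above:

x = jnp.arange(8, dtype=jnp.float32)
lowered = jax.jit(add).lower(x, x)
# The fixed target name shows up in the lowered module text; the
# per-kernel label lives inside the compressed backend_config instead.
print("triton_kernel_call" in lowered.as_text())  # expected: True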
