diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 2b14fb42..c25e3dc6 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -258,7 +258,7 @@ def triton_kernel_call_lowering( *array_args, fn, scalar_args, - call_name, + name, out_shapes, grid, num_warps, @@ -273,6 +273,7 @@ def triton_kernel_call_lowering( raise NotImplementedError( "`input_output_aliases` only supported on `jaxlib>=0.3.22") + kernel_call_name = name args = list(ctx.avals_in) arg_dtypes = list(map(get_triton_type, ctx.avals_in)) for idx, dtype, v in scalar_args: @@ -402,7 +403,7 @@ def prune_configs(configs, named_args): for input_idx, output_idx in input_output_aliases ) kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall( - f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}", + f"{kernel_call_name} ({fn.fn.__name__}) {named_scalar_args}", [(call, str(config)) for call, config in zip(kernel_calls, configs)], input_output_aliases_with_sizes, ) @@ -413,12 +414,15 @@ def prune_configs(configs, named_args): ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) for shape in out_shapes ] - + if jaxlib.version.__version_info__ >= (0, 4, 15): + call_proto = kernel_call.to_proto(kernel_call_name, serialized_metadata) + else: + call_proto = kernel_call.to_proto(serialized_metadata) return jaxlib.hlo_helpers.custom_call( - call_target_name=call_name, + call_target_name="triton_kernel_call", out_types=out_types, operands=array_args, - backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)), + backend_config=zlib.compress(call_proto), operand_layouts=avals_to_layouts(ctx.avals_in), result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), @@ -444,7 +448,7 @@ def triton_call( kernel: triton.JITFunction, out_shape: Union[ShapeDtype, Sequence[ShapeDtype]], grid: GridOrLambda, - call_name: str = "triton_kernel_call", + name: str = "", num_warps: int = 4, num_stages: int = 2, input_output_aliases: Optional[Dict[int, int]] = None, @@ -539,9 +543,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: raise ValueError( "`triton_call` is only available when `triton` is installed." ) - xc.register_custom_call_target( - call_name, triton_kernel_call_lib.get_custom_call(), platform="CUDA" - ) out_shape = tree_util.tree_map( lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape) flat_args, _ = tree_util.tree_flatten(args) @@ -565,7 +566,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: *array_args, fn=kernel, scalar_args=tuple(scalar_args), - call_name=call_name, + name=name, out_shapes=tuple(flat_out_shapes), grid=grid, num_warps=num_warps,