diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 2b14fb4..f736a4c 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -258,6 +258,7 @@ def triton_kernel_call_lowering( *array_args, fn, scalar_args, + name, call_name, out_shapes, grid, @@ -273,6 +274,7 @@ def triton_kernel_call_lowering( raise NotImplementedError( "`input_output_aliases` only supported on `jaxlib>=0.3.22") + kernel_call_name = name args = list(ctx.avals_in) arg_dtypes = list(map(get_triton_type, ctx.avals_in)) for idx, dtype, v in scalar_args: @@ -413,12 +415,15 @@ def prune_configs(configs, named_args): ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) for shape in out_shapes ] - + if jaxlib.version.__version_info__ >= (0, 4, 15): + call_proto = kernel_call.to_proto(kernel_call_name, serialized_metadata) + else: + call_proto = kernel_call.to_proto(serialized_metadata) return jaxlib.hlo_helpers.custom_call( call_target_name=call_name, out_types=out_types, operands=array_args, - backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)), + backend_config=zlib.compress(call_proto), operand_layouts=avals_to_layouts(ctx.avals_in), result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), @@ -444,7 +449,8 @@ def triton_call( kernel: triton.JITFunction, out_shape: Union[ShapeDtype, Sequence[ShapeDtype]], grid: GridOrLambda, - call_name: str = "triton_kernel_call", + name: str = "", + call_name: str = "triton_kernel_call", # TODO(cjfj): Remove this. num_warps: int = 4, num_stages: int = 2, input_output_aliases: Optional[Dict[int, int]] = None, @@ -565,6 +571,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: *array_args, fn=kernel, scalar_args=tuple(scalar_args), + name=name, call_name=call_name, out_shapes=tuple(flat_out_shapes), grid=grid,