diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index d4f2cee0..8e678d8f 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -419,13 +419,13 @@ def prune_configs(configs, named_args): call_proto = kernel_call.to_proto(serialized_metadata) return jaxlib.hlo_helpers.custom_call( call_target_name="triton_kernel_call", - out_types=out_types, + result_types=out_types, operands=array_args, backend_config=zlib.compress(call_proto), operand_layouts=avals_to_layouts(ctx.avals_in), result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), - ) + ).results mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)