Skip to content

Commit

Permalink
Align the custom_call implementation in mlir and hlo_helpers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 562137330
  • Loading branch information
gnecula authored and The jax_triton Authors committed Sep 2, 2023
1 parent c70d4ed commit d01db1f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,13 @@ def prune_configs(configs, named_args):
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name="triton_kernel_call",
out_types=out_types,
result_types=out_types,
operands=array_args,
backend_config=zlib.compress(call_proto),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
)
).results


mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)
Expand Down

0 comments on commit d01db1f

Please sign in to comment.