diff --git a/src/Ops.jl b/src/Ops.jl index 02558d9737..aaaee500c8 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2366,14 +2366,18 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`, result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape) - sample_inputs = [Reactant.ConcretePJRTNumber(T(0)), Reactant.ConcretePJRTNumber(T(0))] + sample_inputs = [ + Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0), + Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0) + ] func = Reactant.TracedUtils.make_mlir_fn( fn, (sample_inputs), (), - "reduce_fn"; + "reduce_fn", + false; args_in_result=:none, return_dialect=:stablehlo, ).f