Skip to content

Commit

Permalink
Update src/Ops.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Mar 8, 2025
1 parent 25a1124 commit b14a1e1
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2366,14 +2366,18 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,

result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)

sample_inputs = [Reactant.ConcretePJRTNumber(T(0)), Reactant.ConcretePJRTNumber(T(0))]
sample_inputs = [
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0)
]

func =
Reactant.TracedUtils.make_mlir_fn(
fn,
(sample_inputs),
(),
"reduce_fn";
"reduce_fn",
false;
args_in_result=:none,
return_dialect=:stablehlo,
).f
Expand Down

0 comments on commit b14a1e1

Please sign in to comment.