Skip to content

Commit

Permalink
adding reduce to Ops.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
tharittk committed Mar 3, 2025
1 parent 100c9f7 commit 05fb91b
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
48 changes: 47 additions & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ julia> Reactant.@jit(
MLIR.IR.Attribute("private"),
)

# Change function name
# Change function nane
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
end
end
Expand Down Expand Up @@ -2313,4 +2313,50 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
end
end

@noinline function reduce(
x::TracedRArray{T},
init_values::TracedRNumber{T},
dimensions::Vector{Int},
fn::Function,
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__)
) where {T}
reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))

result_type = mlir_type(TracedRArray{T, length(reduced_shape)}, reduced_shape)

sample_inputs = [Reactant.ConcretePJRTNumber(T(0)), Reactant.ConcretePJRTNumber(T(0))]

func =
Reactant.TracedUtils.make_mlir_fn(
fn,
(sample_inputs),
(),
"reduce_fn";
args_in_result=:none,
return_dialect=:stablehlo,
).f
@assert MLIR.IR.nregions(func) == 1
fn_name = String(
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
)
@assert fn_name == "reduce_fn"
ftype_attr = MLIR.IR.attr(func, "function_type")
ftype = MLIR.IR.Type(ftype_attr)
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error (
"$fn return type is not tensor<i1>"
)
fn = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
MLIR.IR.rmfromparent!(func)

dimensions = MLIR.IR.Attribute(dimensions .- 1)

res = MLIR.IR.result(stablehlo.reduce(
[x.mlir_data], [init_values.mlir_data];
result_0=[result_type], dimensions=dimensions, body=fn, location=location
))

return TracedRArray{T, length(reduced_shape)}((), res, reduced_shape)
end

end # module Ops
32 changes: 32 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1071,3 +1071,35 @@ end

@test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] (x .+ y) .* y
end

@testset "reduce" begin
# stablehlo reduce collapse the dimension so that (1,3) beomces (3, )
# while Julia reduce retains (1, 3). The test will fail despite elements being equal
function squeeze_dims(r)
return dropdims(r,dims=tuple(findall(size(r).==1)...))
end

A = rand(3, 4, 5)
A_ra = Reactant.to_rarray(A)
init = 2.1
init_ra = @jit Reactant.Ops.constant(init)

dims = [2]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
r = reduce(*, A; dims=dims, init=init)
@test r_hlo squeeze_dims(r)

dims = [1,3]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
r = reduce(+, A; dims=dims, init=init)
@test r_hlo squeeze_dims(r)

dims = [1,2,3]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
r = reduce(+, A; dims=dims, init=init)
@test r_hlo squeeze_dims(r)

end



0 comments on commit 05fb91b

Please sign in to comment.