feat: add reduce to Ops.jl (#840)
* adding reduce to Ops.jl

* Update src/Ops.jl

* change Ops.reduce test case to reflect stablehlo semantics

* Run through formatter

* add docstring as comments suggest

* Update src/Ops.jl

* Update src/Ops.jl
tharittk authored Mar 8, 2025
1 parent 2d64ecd commit 5f9d523
Showing 2 changed files with 127 additions and 0 deletions.
97 changes: 97 additions & 0 deletions src/Ops.jl
@@ -2313,4 +2313,101 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
end
end

"""
reduce(
x::TracedRArray{T},
init_values::TracedRNumber{T},
dimensions::Vector{Int},
fn::Function,
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Applies a reduction function `fn` along the specified `dimensions` of input `x`, starting from `init_values`.
# Arguments
- `x`: The input array.
- `init_values`: The initial value.
- `dimensions`: The dimensions to reduce along.
- `fn`: A binary operator.
!!! warning
    This reduction operation follows StableHLO semantics. The key difference
    between this operation and Julia's built-in `reduce` is the following:

    - The function `fn` and the initial value `init_values` must form a **monoid**:
      - `fn` must be an **associative** binary operation.
      - `init_values` must be the **identity element** of `fn`.
    - This constraint ensures consistent results across all implementations.

    If `init_values` is not the identity element of `fn`, the results may vary
    between CPU and GPU executions, as the example below shows (a safe-usage
    sketch follows it). For:

    ```julia
    A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
    init_values = 2
    dimensions = [1, 3]
    ```

    - **CPU version & Julia's `reduce`**:
      - Reduce along dimension 3 → `[(15) (21); (18) (24)]`
      - Reduce along dimension 1 → `[(33 + 2) (45 + 2)]` → `[35 47]`
    - **GPU version**:
      - Reduce along dimension 3 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
      - Reduce along dimension 1 → `[37 49]`
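
# Example

A minimal usage sketch, mirroring the test suite, with `0` as the identity element of `+`:

```julia
using Reactant

A = reshape(collect(1:12), 2, 2, 3)
A_ra = Reactant.to_rarray(A)
init_ra = @jit Reactant.Ops.constant(0)  # identity element of `+`
r = @jit Reactant.Ops.reduce(A_ra, init_ra, [1, 3], +)
# `r` has shape (2,): the reduced dimensions are collapsed, unlike Julia's
# `reduce(+, A; dims=(1, 3))`, which keeps them as size-1 axes.
```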
"""
@noinline function reduce(
x::TracedRArray{T},
init_values::TracedRNumber{T},
dimensions::Vector{Int},
fn::Function,
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
) where {T}
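    # Drop the reduced dimensions from the input shape to obtain the result shape.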
reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))

result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)

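    # Trace `fn` on two scalar arguments of element type T to build the reduction body.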
sample_inputs = [
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0)
]

func =
Reactant.TracedUtils.make_mlir_fn(
fn,
        sample_inputs,
(),
"reduce_fn",
false;
args_in_result=:none,
return_dialect=:stablehlo,
).f
@assert MLIR.IR.nregions(func) == 1
fn_name = String(
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
)
ftype_attr = MLIR.IR.attr(func, "function_type")
ftype = MLIR.IR.Type(ftype_attr)
    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) (
        "$fn return type is not a scalar tensor of element type $T"
    )
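    # Move the traced function's body into a standalone region for the reduce op,
    # then remove the now-empty temporary function from the module.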
fn = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
MLIR.IR.rmfromparent!(func)

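    # StableHLO dimension indices are 0-based, while Julia's are 1-based.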
dimensions = MLIR.IR.Attribute(dimensions .- 1)

res = MLIR.IR.result(
stablehlo.reduce(
[x.mlir_data],
[init_values.mlir_data];
result_0=[result_type],
dimensions=dimensions,
body=fn,
location=location,
),
)

return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
end

end # module Ops
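
To make the docstring's warning concrete, the CPU-semantics column of its example can be reproduced in plain Julia (this assumes Julia ≥ 1.7 for the `;;;` array-literal syntax):

```julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]  # 2×2×3 array
reduce(+, A; dims=(1, 3), init=2)           # 1×2×1 array containing 35 and 47
```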
30 changes: 30 additions & 0 deletions test/ops.jl
@@ -1071,3 +1071,33 @@ end

    @test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] ≈ (x .+ y) .* y
end

@testset "reduce" begin
    # StableHLO's reduce collapses the reduced dimensions, so (1, 3) becomes (3,),
    # while Julia's reduce keeps them as size-1 axes; without squeezing, the test
    # would fail even though the elements are equal.
function squeeze_dims(r)
return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
end

    # Floating-point operations are not associative, so test with integer data.
A = rand(Int64, 3, 4, 5)
A_ra = Reactant.to_rarray(A)
init = 1
init_ra = @jit Reactant.Ops.constant(init)
dims = [2]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
r = reduce(*, A; dims=dims, init=init)
    @test r_hlo ≈ squeeze_dims(r)

dims = [1, 3]
init = 0
init_ra = @jit Reactant.Ops.constant(init)
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
r = reduce(+, A; dims=dims, init=init)
    @test r_hlo ≈ squeeze_dims(r)

dims = [1, 2, 3]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
r = reduce(+, A; dims=dims, init=init)
    @test r_hlo ≈ squeeze_dims(r)
end
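
As a footnote to the comment in the testset above, floating-point addition is genuinely non-associative, which is why the test data uses `Int64`:

```julia
(0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3)  # false: 0.6000000000000001 vs 0.6
```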
