Commit 5f9d523

feat: add reduce to Ops.jl (EnzymeAD#840)
* adding reduce to Ops.jl
* Update src/Ops.jl
* change Ops.reduce test case to reflect stablehlo semantics
* Run through formatter
* add docstring as comments suggest
* Update src/Ops.jl
* Update src/Ops.jl
1 parent 2d64ecd commit 5f9d523

File tree

2 files changed: +127 −0 lines changed

src/Ops.jl

Lines changed: 97 additions & 0 deletions
@@ -2313,4 +2313,101 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi

````julia
    end
end

"""
    reduce(
        x::TracedRArray{T},
        init_values::TracedRNumber{T},
        dimensions::Vector{Int},
        fn::Function,
        location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
    )

Applies a reduction function `fn` along the specified `dimensions` of input `x`, starting from `init_values`.

# Arguments

- `x`: The input array.
- `init_values`: The initial value.
- `dimensions`: The dimensions to reduce along.
- `fn`: A binary operator.

!!! warning
    This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in `reduce` is explained below:

    - The function `fn` and the initial value `init_values` must form a **monoid**, meaning:
      - `fn` must be an **associative** binary operation.
      - `init_values` must be the **identity element** associated with `fn`.
    - This constraint ensures consistent results across all implementations.

    If `init_values` is not the identity element of `fn`, the results may vary between CPU and GPU executions. For example:

    ```julia
    A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
    init_values = 2
    dimensions = [1, 3]
    ```

    - **CPU version & Julia's `reduce`**:
      - Reduce along dimension 3 → `[(15) (21); (18) (24)]`
      - Reduce along dimension 1 → `[(33 + 2) (45 + 2)]` → `[35 47]`

    - **GPU version**:
      - Reduce along dimension 3 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
      - Reduce along dimension 1 → `[37 49]`
"""
@noinline function reduce(
    x::TracedRArray{T},
    init_values::TracedRNumber{T},
    dimensions::Vector{Int},
    fn::Function,
    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
) where {T}
    # Drop the reduced dimensions from the input shape (StableHLO collapses them).
    reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))

    result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)

    # Trace `fn` as a scalar binary function to build the reduction body.
    sample_inputs = [
        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
    ]

    func = Reactant.TracedUtils.make_mlir_fn(
        fn,
        sample_inputs,
        (),
        "reduce_fn",
        false;
        args_in_result=:none,
        return_dialect=:stablehlo,
    ).f
    @assert MLIR.IR.nregions(func) == 1
    fn_name = String(
        MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
    )
    ftype_attr = MLIR.IR.attr(func, "function_type")
    ftype = MLIR.IR.Type(ftype_attr)
    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) (
        "$fn_name return type is not tensor<$T>"
    )
    # Move the traced body into a fresh region and discard the temporary function.
    fn = MLIR.IR.Region()
    MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
    MLIR.IR.rmfromparent!(func)

    # StableHLO dimension indices are 0-based.
    dimensions = MLIR.IR.Attribute(dimensions .- 1)

    res = MLIR.IR.result(
        stablehlo.reduce(
            [x.mlir_data],
            [init_values.mlir_data];
            result_0=[result_type],
            dimensions=dimensions,
            body=fn,
            location=location,
        ),
    )

    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
end

end # module Ops
````
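The CPU/GPU divergence described in the docstring's warning can be reproduced in plain Julia by folding the non-identity `init` into each partial reduction. The snippet below is an illustration of the semantics only, not code from this commit; it needs no Reactant:

```julia
# init = 2 is not the identity of +, so the result depends on how many
# times an implementation folds it in.
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]  # 2×2×3 array

# CPU-style (and Julia's `reduce`): init is applied once per output element.
cpu = reduce(+, A; dims=(1, 3), init=2)             # contains [35 47]

# GPU-style: init is folded into every dim-3 partial sum, and the partials
# are then combined along dim 1 without reapplying it.
partials = dropdims(sum(A; dims=3); dims=3) .+ 2    # [17 23; 20 26]
gpu = sum(partials; dims=1)                         # [37 49]
```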

test/ops.jl

Lines changed: 30 additions & 0 deletions
@@ -1071,3 +1071,33 @@ end

```julia
    @test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] ≈ (x .+ y) .* y
end

@testset "reduce" begin
    # stablehlo reduce collapses the reduced dimensions, so a (1, 3) result
    # becomes (3,), while Julia's reduce retains the (1, 3) shape. Without
    # squeezing, the tests would fail even though the elements are equal.
    function squeeze_dims(r)
        return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
    end

    # Floating-point operations are not associative, so test with integers.
    A = rand(Int64, 3, 4, 5)
    A_ra = Reactant.to_rarray(A)
    init = 1
    init_ra = @jit Reactant.Ops.constant(init)
    dims = [2]
    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
    r = reduce(*, A; dims=dims, init=init)
    @test r_hlo ≈ squeeze_dims(r)

    dims = [1, 3]
    init = 0
    init_ra = @jit Reactant.Ops.constant(init)
    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
    r = reduce(+, A; dims=dims, init=init)
    @test r_hlo ≈ squeeze_dims(r)

    dims = [1, 2, 3]
    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
    r = reduce(+, A; dims=dims, init=init)
    @test r_hlo ≈ squeeze_dims(r)
end
```
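For reference, the shape mismatch that `squeeze_dims` works around can be seen directly in Base Julia; the snippet below is illustrative only and does not require Reactant:

```julia
A = rand(Int64, 3, 4, 5)

# Julia's reduce keeps the reduced dimensions as singletons:
size(reduce(+, A; dims=(1, 3), init=0))                           # (1, 4, 1)

# stablehlo.reduce collapses them, so Ops.reduce returns a (4,) result;
# dropdims reproduces that collapse for comparison:
size(dropdims(reduce(+, A; dims=(1, 3), init=0); dims=(1, 3)))    # (4,)
```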
