diff --git a/src/Ops.jl b/src/Ops.jl
index b49ca654c..13d2ecbc7 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -2313,4 +2313,102 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
 end
 end
 
+"""
+    reduce(
+        x::TracedRArray{T},
+        init_values::TracedRNumber{T},
+        dimensions::Vector{Int},
+        fn::Function,
+        location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
+    )
+
+Applies the reduction function `fn` along the specified `dimensions` of the input `x`, starting from `init_values`.
+
+# Arguments
+
+- `x`: The input array.
+- `init_values`: The initial value.
+- `dimensions`: The dimensions to reduce along.
+- `fn`: A binary operator.
+
+!!! warning
+    This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in `reduce` is the following:
+
+    - The function `fn` and the initial value `init_values` must form a **monoid**, meaning:
+      - `fn` must be an **associative** binary operation, and
+      - `init_values` must be the **identity element** of `fn`.
+    - This constraint ensures consistent results across all implementations.
+
+    If `init_values` is not the identity element of `fn`, the results may vary between CPU and GPU executions. For example:
+
+    ```julia
+    A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
+    init_values = 2
+    dimensions = [1, 3]
+    ```
+
+    - **CPU version & Julia's `reduce`**:
+      - Reduce along dimension 1 → `[(15) (21); (18) (24)]`
+      - Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`
+
+    - **GPU version**:
+      - Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
+      - Reduce along dimension 3 → `[37 49]`
+"""
+@noinline function reduce(
+    x::TracedRArray{T},
+    init_values::TracedRNumber{T},
+    dimensions::Vector{Int},
+    fn::Function,
+    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
+) where {T}
+    reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
+
+    result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
+
+    sample_inputs = [
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
+    ]
+
+    func =
+        Reactant.TracedUtils.make_mlir_fn(
+            fn,
+            sample_inputs,
+            (),
+            "reduce_fn",
+            false;
+            args_in_result=:none,
+            return_dialect=:stablehlo,
+        ).f
+    @assert MLIR.IR.nregions(func) == 1
+    fn_name = String(
+        MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
+    )
+    @assert fn_name == "reduce_fn"
+    ftype_attr = MLIR.IR.attr(func, "function_type")
+    ftype = MLIR.IR.Type(ftype_attr)
+    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) (
+        "$fn return type is not a scalar tensor of $T"
+    )
+    fn = MLIR.IR.Region()
+    MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
+    MLIR.IR.rmfromparent!(func)
+
+    dimensions = MLIR.IR.Attribute(dimensions .- 1)
+
+    res = MLIR.IR.result(
+        stablehlo.reduce(
+            [x.mlir_data],
+            [init_values.mlir_data];
+            result_0=[result_type],
+            dimensions=dimensions,
+            body=fn,
+            location=location,
+        ),
+    )
+
+    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
+end
+
 end # module Ops
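Reviewer note, not part of the patch: the monoid caveat in the docstring above is the easiest thing to get wrong when calling this op. Below is a minimal sketch of the pitfall, reusing the docstring's data and the same `Ops.constant`/`Ops.reduce` entry points exercised by the tests further down; the exact numeric outcome on the StableHLO side is backend-dependent.

```julia
using Reactant

# 2×2×3 array from the docstring example; (+, init = 2) is NOT a monoid,
# because 2 is not the identity element of +.
A_ra = Reactant.to_rarray([1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12])
init_ra = Reactant.@jit Reactant.Ops.constant(2)

r = Reactant.@jit Reactant.Ops.reduce(A_ra, init_ra, [1, 3], +)
# Julia's reduce folds init exactly once:      [35 47]
# a backend may fold it once per reduced chunk: [37 49]
# Both are legal under StableHLO semantics, hence the identity requirement.
```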
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 43d042141..1c2db8bf1 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -462,103 +462,36 @@ for (jlop, hloop, hlocomp, merge) in
 end
 
 function Base.mapreduce(
-    @nospecialize(f),
+    f::F, # type parameter so that it recompiles for a different (anonymous) function
     @nospecialize(op),
     @nospecialize(A::AnyTracedRArray{T,N});
     dims=:,
     init=nothing,
-) where {T,N}
+) where {T,N,F}
     A = materialize_traced_array(A)
 
+    inp = broadcast(f, A)
+
     if dims isa Int
         dims = [dims]
+    elseif dims == (:)
+        dims = collect(1:N)
+    else
+        dims = collect(dims)
     end
 
-    op_in_T = Core.Compiler.return_type(f, Tuple{T})
-
     if init === nothing
         if op === min
-            init = typemax(op_in_T)
+            init = typemax(T)
         elseif op === max
-            init = typemin(op_in_T)
+            init = typemin(T)
         else
-            init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
+            init = Base.reduce_empty(Base.BottomRF(op), T)
         end
-
-        if typeof(init) != op_in_T
-            op_in_T = typeof(init)
-            A = typeof(init).(A)
-        end
     end
+    init = Ops.constant(init)
 
-    init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
-
-    inp = [broadcast(f, A).mlir_data]
-
-    rdims = Int64[]
-
-    if dims == (:)
-        for i in 0:(N - 1)
-            push!(rdims, i)
-        end
-    else
-        for i in dims
-            push!(rdims, i - 1)
-        end
-    end
-
-    in_tys = [
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
-    ]
-
-    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
-
-    args = (
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
-    )
-
-    resty = MLIR.IR.block!(fnbody) do
-        tmp = TracedUtils.broadcast_to_size(op(args...), ())
-        Ops.return_(tmp)
-        return eltype(MLIR.IR.type(tmp.mlir_data))
-    end
-
-    toonedims = Int[]
-    outdims = Int[]
-    for i in 1:N
-        tmp = if in(i - 1, rdims)
-            1
-        else
-            sz = size(A, i)
-            push!(outdims, sz)
-            sz
-        end
-        push!(toonedims, tmp)
-    end
-
-    TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
-
-    body = MLIR.IR.Region()
-    push!(body, fnbody)
-    red = MLIR.Dialects.stablehlo.reduce(
-        inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
-    )
-
-    red = MLIR.IR.result(red, 1)
-    redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
-
-    if dims != (:)
-        red = Ops.reshape(TracedRArray(red), toonedims...)
-    else
-        if length(outdims) == 0
-            red = TracedRNumber{redT}((), red)
-        else
-            red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
-        end
-    end
-    return red
+    return Ops.reduce(inp, init, dims, op)
 end
 
 function Base.mapreducedim!(
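Reviewer note, not part of the patch: after this rewrite, `Base.mapreduce` is a thin wrapper. It broadcasts `f`, normalizes `dims` to a `Vector{Int}`, synthesizes an identity `init` from `op`, and delegates to `Ops.reduce`. A rough sketch of that lowering, under the assumption that the steps match the new body above:

```julia
using Reactant

A = rand(Float64, 3, 4, 5)
A_ra = Reactant.to_rarray(A)

# `mapreduce(abs2, +, A_ra; dims=2)` now lowers approximately to:
#   inp  = abs2.(A_ra)                 # broadcast(f, A)
#   init = Ops.constant(0.0)           # Base.reduce_empty(Base.BottomRF(+), Float64)
#   Ops.reduce(inp, init, [2], +)      # dims normalized to a Vector{Int}
sumsq(x) = mapreduce(abs2, +, x; dims=2)
r = Reactant.@jit sumsq(A_ra)

size(r) == (3, 5)                           # reduced dim is dropped (Base would give (3, 1, 5))
r ≈ dropdims(sum(abs2, A; dims=2); dims=2)  # values match plain Julia
```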
diff --git a/test/ops.jl b/test/ops.jl
index 928c4d3d4..0a5ef82bb 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1071,3 +1071,60 @@ end
     @test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] ≈ (x .+ y) .* y
 end
+
+# stablehlo reduce collapses the reduced dimensions, so a (1, 3) result becomes (3,), while
+# Julia's reduce retains (1, 3); without squeezing, the tests would fail despite equal elements.
+function squeeze_dims(r)
+    return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
+end
+
+@testset "reduce" begin
+    # Floating-point operations are not associative, so use integers for exact results.
+    A = rand(Int64, 3, 4, 5)
+    A_ra = Reactant.to_rarray(A)
+    init = 1
+    init_ra = @jit Reactant.Ops.constant(init)
+    dims = [2]
+    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
+    r = reduce(*, A; dims=dims, init=init)
+    @test r_hlo ≈ squeeze_dims(r)
+
+    dims = [1, 3]
+    init = 0
+    init_ra = @jit Reactant.Ops.constant(init)
+    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
+    r = reduce(+, A; dims=dims, init=init)
+    @test r_hlo ≈ squeeze_dims(r)
+
+    dims = [1, 2, 3]
+    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
+    r = reduce(+, A; dims=dims, init=init)
+    @test r_hlo ≈ squeeze_dims(r)
+end
+
+@testset "mapreduce" begin
+    A = rand(Float64, 3, 4, 5)
+    B = rand(Float64, 3, 4, 5)
+    C = rand(Float64, 3, 4, 5)
+    D = rand(Float64, 3, 4, 5)
+
+    A_ra = Reactant.to_rarray(A)
+    B_ra = Reactant.to_rarray(B)
+    C_ra = Reactant.to_rarray(C)
+    D_ra = Reactant.to_rarray(D)
+
+    mr_ra = mapreduce(x -> 3 * x + 1.2, *, A_ra; dims=2)
+    @test mr_ra ≈ squeeze_dims(mapreduce(x -> 3 * x + 1.2, *, A; dims=2))
+
+    mr_ra = mapreduce(x -> x^3, +, A_ra; dims=1:2)
+    @test mr_ra ≈ squeeze_dims(mapreduce(x -> x^3, +, A; dims=1:2))
+
+    mr_ra = mapreduce(x -> 3 * x + 1.2, +, A_ra; dims=:)
+    @test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, +, A; dims=:)
+
+    mr_ra = mapreduce(x -> 3 * x + 1.2, max, A_ra; dims=:)
+    @test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, max, A; dims=:)
+
+    mr_ra = mapreduce(x -> 3 * x + 1.2, min, A_ra; dims=2:3)
+    @test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, min, A; dims=2:3)
+end
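Reviewer note, not part of the patch: a plain-Julia illustration of the shape mismatch that `squeeze_dims` reconciles, plus a caveat about the helper itself:

```julia
# Same helper as in the tests above.
squeeze_dims(r) = dropdims(r; dims=tuple(findall(size(r) .== 1)...))

r = reduce(+, rand(3, 4, 5); dims=[1, 3], init=0.0)
size(r)                # (1, 4, 1): Base keeps reduced dims as singletons
size(squeeze_dims(r))  # (4,): the shape Ops.reduce produces directly
# Caveat: the helper drops *every* singleton dimension, not only the reduced
# ones, so an input with a genuine size-1 axis would be reshaped as well.
```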