From 6c799bfe88eed6cef068801e8a8149018e0ec6fc Mon Sep 17 00:00:00 2001
From: tharittk
Date: Mon, 3 Mar 2025 21:42:24 +0700
Subject: [PATCH 1/8] adding reduce to Ops.jl

---
 src/Ops.jl  | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
 test/ops.jl | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index b49ca654c..564c27918 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1483,7 +1483,7 @@ julia> Reactant.@jit(
         MLIR.IR.Attribute("private"),
     )
 
-    # Change function name
+    # Change function nane
     MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
 end
 end
@@ -2313,4 +2313,50 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
     end
 end
 
+@noinline function reduce(
+    x::TracedRArray{T},
+    init_values::TracedRNumber{T},
+    dimensions::Vector{Int},
+    fn::Function,
+    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__)
+) where {T}
+    reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
+
+    result_type = mlir_type(TracedRArray{T, length(reduced_shape)}, reduced_shape)
+
+    sample_inputs = [Reactant.ConcretePJRTNumber(T(0)), Reactant.ConcretePJRTNumber(T(0))]
+
+    func =
+        Reactant.TracedUtils.make_mlir_fn(
+            fn,
+            (sample_inputs),
+            (),
+            "reduce_fn";
+            args_in_result=:none,
+            return_dialect=:stablehlo,
+        ).f
+    @assert MLIR.IR.nregions(func) == 1
+    fn_name = String(
+        MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
+    )
+    @assert fn_name == "reduce_fn"
+    ftype_attr = MLIR.IR.attr(func, "function_type")
+    ftype = MLIR.IR.Type(ftype_attr)
+    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) (
+        "$fn return type is not tensor"
+    )
+    fn = MLIR.IR.Region()
+    MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
+    MLIR.IR.rmfromparent!(func)
+
+    dimensions = MLIR.IR.Attribute(dimensions .- 1)
+
+    res = MLIR.IR.result(stablehlo.reduce(
+        [x.mlir_data], [init_values.mlir_data];
+        result_0=[result_type], dimensions=dimensions, body=fn, location=location
+    ))
+
+    return TracedRArray{T, length(reduced_shape)}((), res, reduced_shape)
+end
+
 end # module Ops
diff --git a/test/ops.jl b/test/ops.jl
index 928c4d3d4..96eff6ce9 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1071,3 +1071,35 @@ end
     @test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] ≈
         (x .+ y) .* y
 end
+
+@testset "reduce" begin
+    # stablehlo reduce collapses the reduced dimensions, so that (1, 3) becomes (3,)
+    # while Julia reduce retains (1, 3). The test would fail despite the elements being equal
+    function squeeze_dims(r)
+        return dropdims(r,dims=tuple(findall(size(r).==1)...))
+    end
+
+    A = rand(3, 4, 5)
+    A_ra = Reactant.to_rarray(A)
+    init = 2.1
+    init_ra = @jit Reactant.Ops.constant(init)
+
+    dims = [2]
+    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
+    r = reduce(*, A; dims=dims, init=init)
+    @test r_hlo ≈ squeeze_dims(r)
+
+    dims = [1,3]
+    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
+    r = reduce(+, A; dims=dims, init=init)
+    @test r_hlo ≈ squeeze_dims(r)
+
+    dims = [1,2,3]
+    r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
+    r = reduce(+, A; dims=dims, init=init)
+    @test r_hlo ≈ squeeze_dims(r)
+
+end
+
+
+

From b2b1e32697961cccaba9b60672a89fde509eb5e6 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 3 Mar 2025 17:29:54 -0500
Subject: [PATCH 2/8] Update src/Ops.jl

---
 src/Ops.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 564c27918..eb07ca95a 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1483,7 +1483,7 @@ julia> Reactant.@jit(
         MLIR.IR.Attribute("private"),
     )
 
-    # Change function nane
+    # Change function name
     MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
 end
 end

From 7592effaf8ccce48a437e55b51d2865ab947e7fd Mon Sep 17 00:00:00 2001
From: Tharit Tangkijwanichakul
Date: Wed, 5 Mar 2025 08:46:37 +0700
Subject: [PATCH 3/8] change Ops.reduce test case to reflect stablehlo semantics

---
 test/ops.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/ops.jl b/test/ops.jl
index 96eff6ce9..03ea1b31e 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1079,17 +1079,19 @@ end
         return dropdims(r,dims=tuple(findall(size(r).==1)...))
     end
 
-    A = rand(3, 4, 5)
+    # Floating point operations are not associative
+    A = rand(Int64, 3, 4, 5)
     A_ra = Reactant.to_rarray(A)
-    init = 2.1
+    init = 1
     init_ra = @jit Reactant.Ops.constant(init)
-
     dims = [2]
     r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
     r = reduce(*, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
 
     dims = [1,3]
+    init = 0
+    init_ra = @jit Reactant.Ops.constant(init)
     r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
     r = reduce(+, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
@@ -1097,9 +1099,7 @@
     dims = [1,2,3]
     r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
     r = reduce(+, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
 end
-
-

From a999a96e7d8dad3b4a3c26fa8a7bd06290648ca6 Mon Sep 17 00:00:00 2001
From: tharittk
Date: Thu, 6 Mar 2025 09:19:27 +0700
Subject: [PATCH 4/8] Run through formatter

---
 src/Ops.jl  | 26 ++++++++++++++++----------
 test/ops.jl | 14 ++++++--------
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index eb07ca95a..9d0822924 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -2318,12 +2318,12 @@ end
     init_values::TracedRNumber{T},
     dimensions::Vector{Int},
     fn::Function,
-    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__)
+    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
 ) where {T}
     reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
 
-    result_type = mlir_type(TracedRArray{T, length(reduced_shape)}, reduced_shape)
-
+    result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
+
     sample_inputs = [Reactant.ConcretePJRTNumber(T(0)), Reactant.ConcretePJRTNumber(T(0))]
 
     func =
@@ -2348,15 +2348,21 @@ end
     fn = MLIR.IR.Region()
     MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
    MLIR.IR.rmfromparent!(func)
-
-    dimensions = MLIR.IR.Attribute(dimensions .- 1)
 
-    res = MLIR.IR.result(stablehlo.reduce(
-        [x.mlir_data], [init_values.mlir_data];
-        result_0=[result_type], dimensions=dimensions, body=fn, location=location
-    ))
+    dimensions = MLIR.IR.Attribute(dimensions .- 1)
+
+    res = MLIR.IR.result(
+        stablehlo.reduce(
+            [x.mlir_data],
+            [init_values.mlir_data];
+            result_0=[result_type],
+            dimensions=dimensions,
+            body=fn,
+            location=location,
+        ),
+    )
 
-    return TracedRArray{T, length(reduced_shape)}((), res, reduced_shape)
+    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
 end
 
 end # module Ops
diff --git a/test/ops.jl b/test/ops.jl
index 03ea1b31e..2a0b1bb66 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1076,30 +1076,28 @@ end
     # stablehlo reduce collapses the reduced dimensions, so that (1, 3) becomes (3,)
     # while Julia reduce retains (1, 3). The test would fail despite the elements being equal
     function squeeze_dims(r)
-        return dropdims(r,dims=tuple(findall(size(r).==1)...))
+        return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
     end
 
     # Floating point operations are not associative
     A = rand(Int64, 3, 4, 5)
     A_ra = Reactant.to_rarray(A)
-    init = 1 
+    init = 1
     init_ra = @jit Reactant.Ops.constant(init)
-    dims = [2] 
+    dims = [2]
     r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
     r = reduce(*, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
 
-    dims = [1,3]
-    init = 0 
+    dims = [1, 3]
+    init = 0
     init_ra = @jit Reactant.Ops.constant(init)
     r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
     r = reduce(+, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
 
-    dims = [1,2,3]
+    dims = [1, 2, 3]
     r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
     r = reduce(+, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
-
 end
-

From 85388b05ff578e1722297ba5def89a5bac5b6d13 Mon Sep 17 00:00:00 2001
From: tharittk
Date: Thu, 6 Mar 2025 17:02:42 +0700
Subject: [PATCH 5/8] add docstring as comments suggest

---
 src/Ops.jl | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/src/Ops.jl b/src/Ops.jl
index 9d0822924..02558d973 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -2313,6 +2313,48 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
     end
 end
 
+"""
+    reduce(
+        x::TracedRArray{T},
+        init_values::TracedRNumber{T},
+        dimensions::Vector{Int},
+        fn::Function,
+        location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
+    )
+
+Applies a reduction function `fn` along the specified `dimensions` of input `x`, starting from `init_values`.
+
+# Arguments
+
+- `x`: The input array.
+- `init_values`: The initial value.
+- `dimensions`: The dimensions to reduce along.
+- `fn`: A binary operator.
+
+!!! warning
+    This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in `reduce` is explained below:
+
+    - The function `fn` and the initial value `init_values` must form a **monoid**, meaning:
+      - `fn` must be an **associative** binary operation.
+      - `init_values` must be the **identity element** associated with `fn`.
+    - This constraint ensures consistent results across all implementations.
+
+    If `init_values` is not the identity element of `fn`, the results may vary between CPU and GPU executions. For example:
+
+    ```julia
+    A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
+    init_values = 2
+    dimensions = [1, 3]
+    ```
+
+    - **CPU version & Julia's `reduce`**:
+      - Reduce along dimension 3 → `[(15) (21); (18) (24)]`
+      - Reduce along dimension 1 → `[(33 + 2) (45 + 2)]` → `[35 47]`
+
+    - **GPU version**:
+      - Reduce along dimension 3 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
+      - Reduce along dimension 1 → `[37 49]`
+"""
 @noinline function reduce(
     x::TracedRArray{T},
     init_values::TracedRNumber{T},

From 45ebd5ef49dcf163064849b98038d2bc4dcf28e4 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 7 Mar 2025 21:58:20 -0500
Subject: [PATCH 6/8] Update src/Ops.jl

---
 src/Ops.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 02558d973..aaaee500c 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -2366,14 +2366,18 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
 
     result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
 
-    sample_inputs = [Reactant.ConcretePJRTNumber(T(0)), Reactant.ConcretePJRTNumber(T(0))]
+    sample_inputs = [
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0)
+    ]
 
     func =
         Reactant.TracedUtils.make_mlir_fn(
             fn,
             (sample_inputs),
             (),
-            "reduce_fn";
+            "reduce_fn",
+            false;
             args_in_result=:none,
             return_dialect=:stablehlo,
         ).f

From ab6d1782c0cf9b8bfb7c70f851149f1ad5078e7d Mon Sep 17 00:00:00 2001
From: tharittk
Date: Sat, 8 Mar 2025 19:19:44 +0700
Subject: [PATCH 7/8] change mapreduce to use Ops.reduce

---
 src/TracedRArray.jl | 93 +++++++--------------------------------------
 test/ops.jl         | 39 ++++++++++++++++---
 2 files changed, 46 insertions(+), 86 deletions(-)

diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 43d042141..1c2db8bf1 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -462,103 +462,36 @@ for (jlop, hloop, hlocomp, merge) in
 end
 
 function Base.mapreduce(
-    @nospecialize(f),
+    f::F, # type parameter so that it recompiles for a different (anonymous) function
     @nospecialize(op),
     @nospecialize(A::AnyTracedRArray{T,N});
     dims=:,
     init=nothing,
-) where {T,N}
+) where {T,N,F}
     A = materialize_traced_array(A)
 
+    inp = broadcast(f, A)
+
     if dims isa Int
         dims = [dims]
+    elseif dims == (:)
+        dims = collect(1:N)
+    else
+        dims = collect(dims)
     end
 
-    op_in_T = Core.Compiler.return_type(f, Tuple{T})
-
     if init === nothing
         if op === min
-            init = typemax(op_in_T)
+            init = typemax(T)
         elseif op === max
-            init = typemin(op_in_T)
-        else
-            init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
-        end
-
-        if typeof(init) != op_in_T
-            op_in_T = typeof(init)
-            A = typeof(init).(A)
-        end
-    end
-
-    init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
-
-    inp = [broadcast(f, A).mlir_data]
-
-    rdims = Int64[]
-
-    if dims == (:)
-        for i in 0:(N - 1)
-            push!(rdims, i)
-        end
-    else
-        for i in dims
-            push!(rdims, i - 1)
-        end
-    end
-
-    in_tys = [
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
-    ]
-
-    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
-
-    args = (
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
-    )
-
-    resty = MLIR.IR.block!(fnbody) do
-        tmp = TracedUtils.broadcast_to_size(op(args...), ())
-        Ops.return_(tmp)
-        return eltype(MLIR.IR.type(tmp.mlir_data))
-    end
-
-    toonedims = Int[]
-    outdims = Int[]
-    for i in 1:N
-        tmp = if in(i - 1, rdims)
-            1
+            init = typemin(T)
         else
-            sz = size(A, i)
-            push!(outdims, sz)
-            sz
+            init = Base.reduce_empty(Base.BottomRF(op), T)
         end
-        push!(toonedims, tmp)
     end
+    init = Ops.constant(init)
 
-    TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
-
-    body = MLIR.IR.Region()
-    push!(body, fnbody)
-    red = MLIR.Dialects.stablehlo.reduce(
-        inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
-    )
-
-    red = MLIR.IR.result(red, 1)
-    redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
-
-    if dims != (:)
-        red = Ops.reshape(TracedRArray(red), toonedims...)
-    else
-        if length(outdims) == 0
-            red = TracedRNumber{redT}((), red)
-        else
-            red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
-        end
-    end
-    return red
+    return Ops.reduce(inp, init, dims, op)
 end
 
 function Base.mapreducedim!(
diff --git a/test/ops.jl b/test/ops.jl
index 2a0b1bb66..a6a463501 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1072,13 +1072,13 @@ end
         (x .+ y) .* y
 end
 
-@testset "reduce" begin
-    # stablehlo reduce collapses the reduced dimensions, so that (1, 3) becomes (3,)
-    # while Julia reduce retains (1, 3). The test would fail despite the elements being equal
-    function squeeze_dims(r)
-        return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
-    end
+# stablehlo reduce collapses the reduced dimensions, so that (1, 3) becomes (3,)
+# while Julia reduce retains (1, 3). The test would fail despite the elements being equal
+function squeeze_dims(r)
+    return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
+end
 
+@testset "reduce" begin
     # Floating point operations are not associative
     A = rand(Int64, 3, 4, 5)
     A_ra = Reactant.to_rarray(A)
@@ -1101,3 +1101,30 @@ end
     r = reduce(+, A; dims=dims, init=init)
     @test r_hlo ≈ squeeze_dims(r)
 end
+
+@testset "mapreduce" begin
+    A = rand(Float64, 3, 4, 5)
+    B = rand(Float64, 3, 4, 5)
+    C = rand(Float64, 3, 4, 5)
+    D = rand(Float64, 3, 4, 5)
+
+    A_ra = Reactant.to_rarray(A)
+    B_ra = Reactant.to_rarray(B)
+    C_ra = Reactant.to_rarray(C)
+    D_ra = Reactant.to_rarray(D)
+
+    mr_ra = mapreduce(x -> 3*x+1.2, *, A_ra; dims=2)
+    @test mr_ra ≈ squeeze_dims(mapreduce(x -> 3*x+1.2, *, A; dims=2))
+
+    mr_ra = mapreduce(x -> x^3, +, A_ra; dims=1:2)
+    @test mr_ra ≈ squeeze_dims(mapreduce(x -> x^3, +, A; dims=1:2))
+
+    mr_ra = mapreduce(x -> 3*x+1.2, +, A_ra; dims=:)
+    @test mr_ra ≈ mapreduce(x -> 3*x+1.2, +, A; dims=:)
+
+    mr_ra = mapreduce(x -> 3*x+1.2, max, A_ra; dims=:)
+    @test mr_ra ≈ mapreduce(x -> 3*x+1.2, max, A; dims=:)
+
+    mr_ra = mapreduce(x -> 3*x+1.2, min, A_ra; dims=2:3)
+    @test mr_ra ≈ mapreduce(x -> 3*x+1.2, min, A; dims=2:3)
+end

From 50c14aced69728b4a89f11ca901e9f7db2461bc9 Mon Sep 17 00:00:00 2001
From: tharittk
Date: Sat, 8 Mar 2025 19:28:55 +0700
Subject: [PATCH 8/8] formatted with JuliaFormatter

---
 src/Ops.jl  |  2 +-
 test/ops.jl | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index aaaee500c..13d2ecbc7 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -2368,7 +2368,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
     sample_inputs = [
         Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
-        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0)
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
     ]
 
     func =
diff --git a/test/ops.jl b/test/ops.jl
index a6a463501..0a5ef82bb 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1113,18 +1113,18 @@ end
     C_ra = Reactant.to_rarray(C)
     D_ra = Reactant.to_rarray(D)
 
-    mr_ra = mapreduce(x -> 3*x+1.2, *, A_ra; dims=2)
-    @test mr_ra ≈ squeeze_dims(mapreduce(x -> 3*x+1.2, *, A; dims=2))
+    mr_ra = mapreduce(x -> 3 * x + 1.2, *, A_ra; dims=2)
+    @test mr_ra ≈ squeeze_dims(mapreduce(x -> 3 * x + 1.2, *, A; dims=2))
 
     mr_ra = mapreduce(x -> x^3, +, A_ra; dims=1:2)
     @test mr_ra ≈ squeeze_dims(mapreduce(x -> x^3, +, A; dims=1:2))
 
-    mr_ra = mapreduce(x -> 3*x+1.2, +, A_ra; dims=:)
-    @test mr_ra ≈ mapreduce(x -> 3*x+1.2, +, A; dims=:)
+    mr_ra = mapreduce(x -> 3 * x + 1.2, +, A_ra; dims=:)
+    @test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, +, A; dims=:)
 
-    mr_ra = mapreduce(x -> 3*x+1.2, max, A_ra; dims=:)
-    @test mr_ra ≈ mapreduce(x -> 3*x+1.2, max, A; dims=:)
+    mr_ra = mapreduce(x -> 3 * x + 1.2, max, A_ra; dims=:)
+    @test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, max, A; dims=:)
 
-    mr_ra = mapreduce(x -> 3*x+1.2, min, A_ra; dims=2:3)
-    @test mr_ra ≈ mapreduce(x -> 3*x+1.2, min, A; dims=2:3)
+    mr_ra = mapreduce(x -> 3 * x + 1.2, min, A_ra; dims=2:3)
+    @test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, min, A; dims=2:3)
 end
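Editor's note: a minimal usage sketch of the API this series adds, not part of the patches themselves. It assumes the final state of the branch and reuses the conventions of `test/ops.jl` above (`@jit`, `to_rarray`, calling `Ops.reduce` and `mapreduce` as the tests do); the array values are illustrative only.

```julia
using Reactant, Test

A = collect(reshape(1:24, 2, 3, 4))  # integer data keeps the reductions exact
A_ra = Reactant.to_rarray(A)

# Per the docstring in PATCH 5, init must be the identity element of the
# reduction operator (0 for +), or CPU and GPU results may diverge.
init_ra = @jit Reactant.Ops.constant(0)

# stablehlo.reduce drops the reduced dimensions: the result here has
# shape (3,), whereas Base.reduce would keep a 1x3x1 array.
r = @jit Reactant.Ops.reduce(A_ra, init_ra, [1, 3], +)
@test r ≈ dropdims(sum(A; dims=(1, 3)); dims=(1, 3))

# After PATCH 7, Base.mapreduce on traced arrays lowers to Ops.reduce too,
# so its result is likewise squeezed relative to Julia's mapreduce.
mr = mapreduce(abs2, +, A_ra; dims=2)
@test mr ≈ dropdims(mapreduce(abs2, +, A; dims=2); dims=2)
```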