Shorten mapreduce by using Ops.reduce #858

Draft · wants to merge 8 commits into base: main
src/Ops.jl (98 additions, 0 deletions)
@@ -2313,4 +2313,102 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
end
end

"""
reduce(
x::TracedRArray{T},
init_values::TracedRNumber{T},
dimensions::Vector{Int},
fn::Function,
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Applies a reduction function `fn` along the specified `dimensions` of input `x`, starting from `init_values`.

# Arguments

- `x`: The input array.
- `init_values`: The initial value.
- `dimensions`: The dimensions to reduce along.
- `fn`: A binary operator.

!!! warning
This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in `reduce` is explained below:

- The function `fn` and the initial value `init_values` must form a **monoid**, meaning:
- `fn` must be an **associative** binary operation.
- `init_values` must be the **identity element** associated with `fn`.
- This constraint ensures consistent results across all implementations.

If `init_values` is not the identity element of `fn`, the results may vary between CPU and GPU executions. For example:

```julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
init_values = 2
dimensions = [1, 3]
```

- **CPU version & Julia's `reduce`**:
- Reduce along dimension 1 → `[(15) (21); (18) (24)]`
- Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`

- **GPU version**:
- Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
- Reduce along dimension 3 → `[37 49]`
"""
@noinline function reduce(
    x::TracedRArray{T},
    init_values::TracedRNumber{T},
    dimensions::Vector{Int},
    fn::Function,
    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
) where {T}
    reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))

    result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)

    sample_inputs = [
        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
    ]

    func =
        Reactant.TracedUtils.make_mlir_fn(
            fn,
            sample_inputs,
            (),
            "reduce_fn",
            false;
            args_in_result=:none,
            return_dialect=:stablehlo,
        ).f
    @assert MLIR.IR.nregions(func) == 1
    fn_name = String(
        MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
    )
    @assert fn_name == "reduce_fn"
    ftype_attr = MLIR.IR.attr(func, "function_type")
    ftype = MLIR.IR.Type(ftype_attr)
    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) "$fn_name must return a scalar tensor of $T"
    body = MLIR.IR.Region()
    MLIR.API.mlirRegionTakeBody(body, MLIR.IR.region(func, 1))
    MLIR.IR.rmfromparent!(func)

    # StableHLO dimension indices are 0-based.
    dimensions = MLIR.IR.Attribute(dimensions .- 1)

    res = MLIR.IR.result(
        stablehlo.reduce(
            [x.mlir_data],
            [init_values.mlir_data];
            result_0=[result_type],
            dimensions=dimensions,
            body=body,
            location=location,
        ),
    )

    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
end

end # module Ops
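For orientation, here is a minimal usage sketch of the new op. It mirrors the calls in `test/ops.jl` further down; the array contents and shapes are arbitrary.

```julia
using Reactant

A_ra = Reactant.to_rarray(rand(Int64, 3, 4, 5))
init_ra = Reactant.@jit Reactant.Ops.constant(0)

# Sum over dimensions 1 and 3; StableHLO drops the reduced dimensions,
# so the result has shape (4,) rather than Julia's (1, 4, 1).
r = Reactant.@jit Reactant.Ops.reduce(A_ra, init_ra, [1, 3], +)
```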
src/TracedRArray.jl (13 additions, 80 deletions)
@@ -462,103 +462,36 @@ for (jlop, hloop, hlocomp, merge) in
end

 function Base.mapreduce(
-    @nospecialize(f),
+    f::F, # type parameter so that it recompiles for a different (anonymous) function
     @nospecialize(op),
     @nospecialize(A::AnyTracedRArray{T,N});
     dims=:,
     init=nothing,
-) where {T,N}
+) where {T,N,F}
     A = materialize_traced_array(A)

+    inp = broadcast(f, A)
+
     if dims isa Int
         dims = [dims]
     elseif dims == (:)
         dims = collect(1:N)
     else
         dims = collect(dims)
     end

-    op_in_T = Core.Compiler.return_type(f, Tuple{T})
-
     if init === nothing
         if op === min
-            init = typemax(op_in_T)
+            init = typemax(T)
         elseif op === max
-            init = typemin(op_in_T)
+            init = typemin(T)
         else
-            init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
+            init = Base.reduce_empty(Base.BottomRF(op), T)
         end
-
-        if typeof(init) != op_in_T
-            op_in_T = typeof(init)
-            A = typeof(init).(A)
-        end
     end
+    init = Ops.constant(init)

-    init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
-
-    inp = [broadcast(f, A).mlir_data]
-
-    rdims = Int64[]
-
-    if dims == (:)
-        for i in 0:(N - 1)
-            push!(rdims, i)
-        end
-    else
-        for i in dims
-            push!(rdims, i - 1)
-        end
-    end
-
-    in_tys = [
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
-    ]
-
-    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
-
-    args = (
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
-    )
-
-    resty = MLIR.IR.block!(fnbody) do
-        tmp = TracedUtils.broadcast_to_size(op(args...), ())
-        Ops.return_(tmp)
-        return eltype(MLIR.IR.type(tmp.mlir_data))
-    end
-
-    toonedims = Int[]
-    outdims = Int[]
-    for i in 1:N
-        tmp = if in(i - 1, rdims)
-            1
-        else
-            sz = size(A, i)
-            push!(outdims, sz)
-            sz
-        end
-        push!(toonedims, tmp)
-    end
-
-    TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
-
-    body = MLIR.IR.Region()
-    push!(body, fnbody)
-    red = MLIR.Dialects.stablehlo.reduce(
-        inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
-    )
-
-    red = MLIR.IR.result(red, 1)
-    redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
-
-    if dims != (:)
-        red = Ops.reshape(TracedRArray(red), toonedims...)
-    else
-        if length(outdims) == 0
-            red = TracedRNumber{redT}((), red)
-        else
-            red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
-        end
-    end
-    return red

Collaborator, commenting on deleted lines -552 to -561:

> these checks need to be preserved to ensure that the final result matches Julia semantics

+    return Ops.reduce(inp, init, dims, op)
 end

function Base.mapreducedim!(
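The reviewer's concern is about shapes, not values: Julia's dimension-wise reductions keep each reduced dimension as size 1, whereas `stablehlo.reduce` drops it, so the wrapper needs the deleted reshape/wrapping logic (or an equivalent) to match Julia semantics. A plain-Julia illustration, not part of the PR:

```julia
A = rand(3, 4, 5)

size(sum(A; dims=2))                    # (3, 1, 5): Julia keeps the reduced dim
size(dropdims(sum(A; dims=2); dims=2))  # (3, 5): the shape stablehlo.reduce produces

sum(A)                                  # dims=(:) yields a plain scalar in Julia
```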
test/ops.jl (57 additions, 0 deletions)
@@ -1071,3 +1071,60 @@

@test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] ≈ (x .+ y) .* y
end

# stablehlo.reduce collapses the reduced dimensions, so a (1, 3) result becomes (3,),
# while Julia's reduce keeps the size-1 dimensions. Without squeezing them, the tests
# below would fail even though the elements are equal.
function squeeze_dims(r)
    return dropdims(r; dims=tuple(findall(size(r) .== 1)...))
end

@testset "reduce" begin
# Floating point operation is not associative
A = rand(Int64, 3, 4, 5)
A_ra = Reactant.to_rarray(A)
init = 1
init_ra = @jit Reactant.Ops.constant(init)
dims = [2]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, *)
r = reduce(*, A; dims=dims, init=init)
@test r_hlo ≈ squeeze_dims(r)

dims = [1, 3]
init = 0
init_ra = @jit Reactant.Ops.constant(init)
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
r = reduce(+, A; dims=dims, init=init)
@test r_hlo ≈ squeeze_dims(r)

dims = [1, 2, 3]
r_hlo = @jit Reactant.Ops.reduce(A_ra, init_ra, dims, +)
r = reduce(+, A; dims=dims, init=init)
@test r_hlo ≈ squeeze_dims(r)
end
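As a cross-check of the identity-element caveat in the `Ops.reduce` docstring, its CPU-side numbers can be reproduced in plain Julia; `init=2` is deliberately not the identity for `+`:

```julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
r = reduce(+, A; dims=(1, 3), init=2)  # on the CPU path, init is folded in once per output element
vec(r) == [35, 47]                     # true, matching the docstring's CPU example
```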

@testset "mapreduce" begin
A = rand(Float64, 3, 4, 5)
B = rand(Float64, 3, 4, 5)
C = rand(Float64, 3, 4, 5)
D = rand(Float64, 3, 4, 5)

A_ra = Reactant.to_rarray(A)
B_ra = Reactant.to_rarray(B)
C_ra = Reactant.to_rarray(C)
D_ra = Reactant.to_rarray(D)

mr_ra = mapreduce(x -> 3 * x + 1.2, *, A_ra; dims=2)
@test mr_ra ≈ squeeze_dims(mapreduce(x -> 3 * x + 1.2, *, A; dims=2))

mr_ra = mapreduce(x -> x^3, +, A_ra; dims=1:2)
@test mr_ra ≈ squeeze_dims(mapreduce(x -> x^3, +, A; dims=1:2))

mr_ra = mapreduce(x -> 3 * x + 1.2, +, A_ra; dims=:)
@test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, +, A; dims=:)

mr_ra = mapreduce(x -> 3 * x + 1.2, max, A_ra; dims=:)
@test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, max, A; dims=:)

mr_ra = mapreduce(x -> 3 * x + 1.2, min, A_ra; dims=2:3)
@test mr_ra ≈ mapreduce(x -> 3 * x + 1.2, min, A; dims=2:3)
end