Skip to content

Commit

Permalink
Generic Reduced handling in parallel reduce (#172)
Browse files Browse the repository at this point in the history
Previously, `Reduced` did not have a meaningful behavior in parallel
`reduce` when the reducing function is not `right`:

    julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=1)
    Empty{Array{T,1} where T}()

    julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=2)
    1-element Array{Int64,1}:
     4

    julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=4)
    2-element Array{Int64,1}:
     3
     4

    julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=5)
    4-element Array{Int64,1}:
     1
     2
     3
     4

This PR fixes it by properly formulating how to execute the reducing
function when combined with `Reduced`.  This is done by "augmenting"
the reducing function `*`:

Given a semigroup `*(::T, ::T) :: T` where `!(Reduced <: T)`, fold
functions in Transducers.jl act on an "augmented" semigroup
`*′(::T′, ::T′) :: T′` where `T′ = Union{T, Reduced{T}}` defined by

    *′(a::Reduced, _) = a
    *′(a::T, b::Reduced) = reduced(a * unreduced(b))
    *′(a::T, b::T) = a * b

If `*` is a monoid with the identity element `e`, the "augmented"
semigroup `*′` is also a monoid with the identity element `e′`.
  • Loading branch information
tkf authored and mergify[bot] committed Jan 20, 2020
1 parent 149fe18 commit 9a84bb3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 10 deletions.
7 changes: 1 addition & 6 deletions src/dreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,7 @@ function dtransduce(
end
# TODO: Cancel remote computation when there is a Reduced.
results = map(fetch, futures)
i = findfirst(isreduced, results)
i === nothing || return results[i]
c = foldl(results) do a, b
combine(rf, a, b)
end
return complete(rf, c)
return complete(rf, combine_all(rf, results))
end

function load_me_everywhere()
Expand Down
18 changes: 14 additions & 4 deletions src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ function _reduce(ctx, rf, init, reducible::Reducible)
a0 = _reduce(fg, rf, init, left)
b0 = fetch(task)
a = @return_if_reduced a0
b = @return_if_reduced b0
should_abort(ctx) && return a # slight optimization
b = unreduced(b0)
b0 isa Reduced && return reduced(combine(rf, a, b))
return combine(rf, a, b)
end
end
Expand Down Expand Up @@ -187,13 +188,22 @@ function _reduce_threads_for(rf, init, reducible::SizedReducible{<:AbstractArray
# `combine` is compute-intensive enough so that launching
# threads is worth enough. Let's merge the `results`
# sequentially for now.
step = combine_step(rf)
return transduce(ensurerf(Completing(step)), Init(step), results)
return combine_all(rf, results)
end
end

function combine_all(rf, results)
step = combine_step(rf)
return transduce(ensurerf(Completing(step)), Init(step), results)
end

combine_step(rf) =
asmonoid((a, b) -> combine(rf, (@return_if_reduced a), (@return_if_reduced b)))
asmonoid() do a0, b0
a = @return_if_reduced a0
b = unreduced(b0)
b0 isa Reduced && return reduced(combine(rf, a, b))
return combine(rf, a, b)
end

# AbstractArray for disambiguation
Base.mapreduce(xform::Transducer, step, itr::AbstractArray;
Expand Down
11 changes: 11 additions & 0 deletions test/test_distributed_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ end
@test dcopy(Map(makerow), StructVector, 1:3, basesize = 2) == StructVector(a = 1:3)
end

@testset "TakeWhile" begin
fname = gensym(:lessthan5)
@everywhere $fname(x) = x < 5
lessthan5 = getproperty(Main, fname)

coll = 1:10
@testset for basesize in 1:(length(coll)+1)
@test dcollect(TakeWhile(lessthan5), coll; basesize = basesize) == 1:4
end
end

@testset "basesize > 0" begin
@test dcollect(Map(identity), [1]) == [1]
end
Expand Down
7 changes: 7 additions & 0 deletions test/test_parallel_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ end
end
end

@testset "TakeWhile" begin
coll = 1:10
@testset for basesize in 1:(length(coll)+1)
@test tcollect(TakeWhile(x -> x < 5), coll; basesize = basesize) == 1:4
end
end

@testset "withprogress" begin
xf = Map() do x
x
Expand Down

0 comments on commit 9a84bb3

Please sign in to comment.