From 9a84bb320ff5787a9e083aceef532c94e5b99c72 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 19 Jan 2020 17:40:14 -0800 Subject: [PATCH] Generic `Reduced` handling in parallel reduce (#172) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, `Reduced` did not have a meaningful behavior in parallel `reduce` when the reducing function is not `right`: julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=1) Empty{Array{T,1} where T}() julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=2) 1-element Array{Int64,1}: 4 julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=4) 2-element Array{Int64,1}: 3 4 julia> tcollect(TakeWhile(x -> x < 5), 1:10; basesize=5) 4-element Array{Int64,1}: 1 2 3 4 This PR fixes it by properly formulating how to execute the reducing function when combined with `Reduced`. This is done by "augmenting" the reducing function `*`: Given a semigroup `*(::T, ::T) :: T` where `!(Reduced <: T)`, fold functions in Transducers.jl act on an "augmented" semigroup `*′(::T′, ::T′) :: T′` where `T′ = Union{T, Reduced{T}}` defined by *′(a::Reduced, _) = a *′(a::T, b::Reduced) = reduced(a * unreduced(b)) *′(a::T, b::T) = a * b If `*` is a monoid with the identity element `e`, the "augmented" semigroup `*′` is also a monoid with the identity element `e′`. --- src/dreduce.jl | 7 +------ src/reduce.jl | 18 ++++++++++++++---- test/test_distributed_reduce.jl | 11 +++++++++++ test/test_parallel_reduce.jl | 7 +++++++ 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/dreduce.jl b/src/dreduce.jl index 5327254b86..7bee309565 100644 --- a/src/dreduce.jl +++ b/src/dreduce.jl @@ -71,12 +71,7 @@ function dtransduce( end # TODO: Cancel remote computation when there is a Reduced. results = map(fetch, futures) - i = findfirst(isreduced, results) - i === nothing || return results[i] - c = foldl(results) do a, b - combine(rf, a, b) - end - return complete(rf, c) + return complete(rf, combine_all(rf, results)) end function load_me_everywhere() diff --git a/src/reduce.jl b/src/reduce.jl index 1044ba5c49..a1af436603 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -157,8 +157,9 @@ function _reduce(ctx, rf, init, reducible::Reducible) a0 = _reduce(fg, rf, init, left) b0 = fetch(task) a = @return_if_reduced a0 - b = @return_if_reduced b0 should_abort(ctx) && return a # slight optimization + b = unreduced(b0) + b0 isa Reduced && return reduced(combine(rf, a, b)) return combine(rf, a, b) end end @@ -187,13 +188,22 @@ function _reduce_threads_for(rf, init, reducible::SizedReducible{<:AbstractArray # `combine` is compute-intensive enough so that launching # threads is worth enough. Let's merge the `results` # sequentially for now. - step = combine_step(rf) - return transduce(ensurerf(Completing(step)), Init(step), results) + return combine_all(rf, results) end end +function combine_all(rf, results) + step = combine_step(rf) + return transduce(ensurerf(Completing(step)), Init(step), results) +end + combine_step(rf) = - asmonoid((a, b) -> combine(rf, (@return_if_reduced a), (@return_if_reduced b))) + asmonoid() do a0, b0 + a = @return_if_reduced a0 + b = unreduced(b0) + b0 isa Reduced && return reduced(combine(rf, a, b)) + return combine(rf, a, b) + end # AbstractArray for disambiguation Base.mapreduce(xform::Transducer, step, itr::AbstractArray; diff --git a/test/test_distributed_reduce.jl b/test/test_distributed_reduce.jl index 5930b522ad..4865c966e9 100644 --- a/test/test_distributed_reduce.jl +++ b/test/test_distributed_reduce.jl @@ -46,6 +46,17 @@ end @test dcopy(Map(makerow), StructVector, 1:3, basesize = 2) == StructVector(a = 1:3) end +@testset "TakeWhile" begin + fname = gensym(:lessthan5) + @everywhere $fname(x) = x < 5 + lessthan5 = getproperty(Main, fname) + + coll = 1:10 + @testset for basesize in 1:(length(coll)+1) + @test dcollect(TakeWhile(lessthan5), coll; basesize = basesize) == 1:4 + end +end + @testset "basesize > 0" begin @test dcollect(Map(identity), [1]) == [1] end diff --git a/test/test_parallel_reduce.jl b/test/test_parallel_reduce.jl index be98c86f78..310ea57dec 100644 --- a/test/test_parallel_reduce.jl +++ b/test/test_parallel_reduce.jl @@ -120,6 +120,13 @@ end end end +@testset "TakeWhile" begin + coll = 1:10 + @testset for basesize in 1:(length(coll)+1) + @test tcollect(TakeWhile(x -> x < 5), coll; basesize = basesize) == 1:4 + end +end + @testset "withprogress" begin xf = Map() do x x