diff --git a/src/dreduce.jl b/src/dreduce.jl index 5327254b86..7bee309565 100644 --- a/src/dreduce.jl +++ b/src/dreduce.jl @@ -71,12 +71,7 @@ function dtransduce( end # TODO: Cancel remote computation when there is a Reduced. results = map(fetch, futures) - i = findfirst(isreduced, results) - i === nothing || return results[i] - c = foldl(results) do a, b - combine(rf, a, b) - end - return complete(rf, c) + return complete(rf, combine_all(rf, results)) end function load_me_everywhere() diff --git a/src/reduce.jl b/src/reduce.jl index 1044ba5c49..a1af436603 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -157,8 +157,9 @@ function _reduce(ctx, rf, init, reducible::Reducible) a0 = _reduce(fg, rf, init, left) b0 = fetch(task) a = @return_if_reduced a0 - b = @return_if_reduced b0 should_abort(ctx) && return a # slight optimization + b = unreduced(b0) + b0 isa Reduced && return reduced(combine(rf, a, b)) return combine(rf, a, b) end end @@ -187,13 +188,22 @@ function _reduce_threads_for(rf, init, reducible::SizedReducible{<:AbstractArray # `combine` is compute-intensive enough so that launching # threads is worth enough. Let's merge the `results` # sequentially for now. - step = combine_step(rf) - return transduce(ensurerf(Completing(step)), Init(step), results) + return combine_all(rf, results) end end +function combine_all(rf, results) + step = combine_step(rf) + return transduce(ensurerf(Completing(step)), Init(step), results) +end + combine_step(rf) = - asmonoid((a, b) -> combine(rf, (@return_if_reduced a), (@return_if_reduced b))) + asmonoid() do a0, b0 + a = @return_if_reduced a0 + b = unreduced(b0) + b0 isa Reduced && return reduced(combine(rf, a, b)) + return combine(rf, a, b) + end # AbstractArray for disambiguation Base.mapreduce(xform::Transducer, step, itr::AbstractArray; diff --git a/test/test_distributed_reduce.jl b/test/test_distributed_reduce.jl index 5930b522ad..4865c966e9 100644 --- a/test/test_distributed_reduce.jl +++ b/test/test_distributed_reduce.jl @@ -46,6 +46,17 @@ end @test dcopy(Map(makerow), StructVector, 1:3, basesize = 2) == StructVector(a = 1:3) end +@testset "TakeWhile" begin + fname = gensym(:lessthan5) + @everywhere $fname(x) = x < 5 + lessthan5 = getproperty(Main, fname) + + coll = 1:10 + @testset for basesize in 1:(length(coll)+1) + @test dcollect(TakeWhile(lessthan5), coll; basesize = basesize) == 1:4 + end +end + @testset "basesize > 0" begin @test dcollect(Map(identity), [1]) == [1] end diff --git a/test/test_parallel_reduce.jl b/test/test_parallel_reduce.jl index be98c86f78..310ea57dec 100644 --- a/test/test_parallel_reduce.jl +++ b/test/test_parallel_reduce.jl @@ -120,6 +120,13 @@ end end end +@testset "TakeWhile" begin + coll = 1:10 + @testset for basesize in 1:(length(coll)+1) + @test tcollect(TakeWhile(x -> x < 5), coll; basesize = basesize) == 1:4 + end +end + @testset "withprogress" begin xf = Map() do x x