Skip to content

Commit ae78e52

Browse files
authored
Merge pull request #659 from mateuszbaran/mbaran/faster-reduction
2 parents 8b77542 + d3623b9 commit ae78e52

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

src/mapreduce.jl

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,23 @@ end
103103
end
104104
end
105105

106-
@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
107-
_mapreduce(f, op, Val(D), nt, sz, a)
106+
@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
107+
# Body of this function is split because constant propagation (at least
108+
# as of Julia 1.2) can't always correctly propagate here and
109+
# as a result the function is not type stable and very slow.
110+
# This makes it at least fast for three dimensions but people should use
111+
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
112+
if D == 1
113+
return _mapreduce(f, op, Val(1), nt, sz, a)
114+
elseif D == 2
115+
return _mapreduce(f, op, Val(2), nt, sz, a)
116+
elseif D == 3
117+
return _mapreduce(f, op, Val(3), nt, sz, a)
118+
else
119+
return _mapreduce(f, op, Val(D), nt, sz, a)
120+
end
121+
end
108122

109-
110123
@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
111124
::Size{S}, a::StaticArray) where {S,D}
112125
N = length(S)
@@ -161,7 +174,9 @@ end
161174
## reduce ##
162175
############
163176

164-
@inline reduce(op, a::StaticArray; kw...) = mapreduce(identity, op, a; kw...)
177+
@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)
178+
179+
@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
165180

166181
#######################
167182
## related functions ##
@@ -186,38 +201,38 @@ end
186201
# TODO: change to use Base.reduce_empty/Base.reduce_first
187202
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)
188203

189-
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(+, a; dims=dims)
190-
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims)
191-
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims) # avoid ambiguity
204+
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
205+
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
206+
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity
192207

193-
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(*, a; dims=dims)
194-
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
195-
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
208+
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
209+
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
210+
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
196211

197-
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(+, a; dims=dims)
198-
@inline count(f, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, +, a; dims=dims)
212+
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
213+
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)
199214

200-
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(&, a; dims=dims, init=true) # non-branching versions
201-
@inline all(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, &, a; dims=dims, init=true)
215+
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
216+
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)
202217

203-
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(|, a; dims=dims, init=false) # (benchmarking needed)
204-
@inline any(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, |, a; dims=dims, init=false) # (benchmarking needed)
218+
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
219+
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)
205220

206-
@inline Base.in(x, a::StaticArray) = mapreduce(==(x), |, a, init=false)
221+
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)
207222

208223
_mean_denom(a, dims::Colon) = length(a)
209224
_mean_denom(a, dims::Int) = size(a, dims)
210225
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
211226
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)
212227

213-
@inline mean(a::StaticArray; dims=:) = sum(a; dims=dims) / _mean_denom(a,dims)
214-
@inline mean(f::Function, a::StaticArray;dims=:) = sum(f, a; dims=dims) / _mean_denom(a,dims)
228+
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
229+
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)
215230

216-
@inline minimum(a::StaticArray; dims=:) = reduce(min, a; dims=dims) # base has mapreduce(idenity, scalarmin, a)
217-
@inline minimum(f::Function, a::StaticArray; dims=:) = mapreduce(f, min, a; dims=dims)
231+
@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
232+
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)
218233

219-
@inline maximum(a::StaticArray; dims=:) = reduce(max, a; dims=dims) # base has mapreduce(idenity, scalarmax, a)
220-
@inline maximum(f::Function, a::StaticArray; dims=:) = mapreduce(f, max, a; dims=dims)
234+
@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
235+
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)
221236

222237
# Diff is slightly different
223238
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)

0 commit comments

Comments
 (0)