@@ -103,10 +103,23 @@ end
103
103
end
104
104
end
105
105
106
- @inline _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S} =
107
- _mapreduce (f, op, Val (D), nt, sz, a)
106
+ @inline function _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S}
107
+ # Body of this function is split because constant propagation (at least
108
+ # as of Julia 1.2) can't always correctly propagate here and
109
+ # as a result the function is not type stable and very slow.
110
+ # This makes it at least fast for three dimensions but people should use
111
+ # for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
112
+ if D == 1
113
+ return _mapreduce (f, op, Val (1 ), nt, sz, a)
114
+ elseif D == 2
115
+ return _mapreduce (f, op, Val (2 ), nt, sz, a)
116
+ elseif D == 3
117
+ return _mapreduce (f, op, Val (3 ), nt, sz, a)
118
+ else
119
+ return _mapreduce (f, op, Val (D), nt, sz, a)
120
+ end
121
+ end
108
122
109
-
110
123
@generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{()} ,
111
124
:: Size{S} , a:: StaticArray ) where {S,D}
112
125
N = length (S)
161
174
# # reduce ##
162
175
# ###########
163
176
164
- @inline reduce (op, a:: StaticArray ; kw... ) = mapreduce (identity, op, a; kw... )
177
+ @inline reduce (op, a:: StaticArray ; dims= :, kw... ) = _reduce (op, a, dims, kw. data)
178
+
179
+ @inline _reduce (op, a:: StaticArray , dims, kw:: NamedTuple = NamedTuple ()) = _mapreduce (identity, op, dims, kw, Size (a), a)
165
180
166
181
# ######################
167
182
# # related functions ##
@@ -186,38 +201,38 @@ end
186
201
# TODO : change to use Base.reduce_empty/Base.reduce_first
187
202
@inline iszero (a:: StaticArray{<:Tuple,T} ) where {T} = reduce ((x,y) -> x && iszero (y), a, init= true )
188
203
189
- @inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = reduce (+ , a; dims = dims)
190
- @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, + , a; dims = dims )
191
- @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, + , a; dims = dims ) # avoid ambiguity
204
+ @inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (+ , a, dims)
205
+ @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a )
206
+ @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a ) # avoid ambiguity
192
207
193
- @inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = reduce (* , a; dims = dims)
194
- @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, * , a; dims = dims )
195
- @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, * , a; dims = dims )
208
+ @inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (* , a, dims)
209
+ @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a )
210
+ @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a )
196
211
197
- @inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = reduce (+ , a; dims = dims)
198
- @inline count (f, a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , + , a; dims = dims )
212
+ @inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (+ , a, dims)
213
+ @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, NamedTuple (), Size (a), a )
199
214
200
- @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = reduce (& , a; dims= dims, init= true ) # non-branching versions
201
- @inline all (f:: Function , a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , & , a; dims= dims, init= true )
215
+ @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, ( init= true ,) ) # non-branching versions
216
+ @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, ( init= true ,), Size (a), a )
202
217
203
- @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = reduce (| , a; dims= dims, init= false ) # (benchmarking needed)
204
- @inline any (f:: Function , a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , | , a; dims= dims, init= false ) # (benchmarking needed)
218
+ @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, ( init= false ,) ) # (benchmarking needed)
219
+ @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, ( init= false ,), Size (a), a ) # (benchmarking needed)
205
220
206
- @inline Base. in (x, a:: StaticArray ) = mapreduce (== (x), | , a, init= false )
221
+ @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, ( init= false ,), Size (a), a )
207
222
208
223
_mean_denom (a, dims:: Colon ) = length (a)
209
224
_mean_denom (a, dims:: Int ) = size (a, dims)
210
225
_mean_denom (a, :: Val{D} ) where {D} = size (a, D)
211
226
_mean_denom (a, :: Type{Val{D}} ) where {D} = size (a, D)
212
227
213
- @inline mean (a:: StaticArray ; dims= :) = sum (a; dims= dims ) / _mean_denom (a,dims)
214
- @inline mean (f:: Function , a:: StaticArray ;dims= :) = sum (f, a; dims= dims) / _mean_denom (a,dims)
228
+ @inline mean (a:: StaticArray ; dims= :) = _reduce ( + , a, dims) / _mean_denom (a, dims)
229
+ @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) / _mean_denom (a, dims)
215
230
216
- @inline minimum (a:: StaticArray ; dims= :) = reduce (min, a; dims = dims) # base has mapreduce(idenity, scalarmin, a)
217
- @inline minimum (f:: Function , a:: StaticArray ; dims= :) = mapreduce (f, min, a; dims = dims )
231
+ @inline minimum (a:: StaticArray ; dims= :) = _reduce (min, a, dims) # base has mapreduce(idenity, scalarmin, a)
232
+ @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, NamedTuple (), Size (a), a )
218
233
219
- @inline maximum (a:: StaticArray ; dims= :) = reduce (max, a; dims = dims) # base has mapreduce(idenity, scalarmax, a)
220
- @inline maximum (f:: Function , a:: StaticArray ; dims= :) = mapreduce (f, max, a; dims = dims )
234
+ @inline maximum (a:: StaticArray ; dims= :) = _reduce (max, a, dims) # base has mapreduce(idenity, scalarmax, a)
235
+ @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, NamedTuple (), Size (a), a )
221
236
222
237
# Diff is slightly different
223
238
@inline diff (a:: StaticArray ; dims) = _diff (Size (a), a, dims)
0 commit comments