60
60
# # mapreduce ##
61
61
# ##############
62
62
63
- @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... )
64
- _mapreduce (f, op, same_size (a, b... ), a, b... )
63
+ @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims = :,kw ... )
64
+ _mapreduce (f, op, dims, kw . data, same_size (a, b... ), a, b... )
65
65
end
66
66
67
- @inline function mapreduce (f, op, v0, a:: StaticArray , b:: StaticArray... )
68
- _mapreduce (f, op, v0, same_size (a, b... ), a, b... )
69
- end
70
-
71
- @generated function _mapreduce (f, op, :: Size{S} , a:: StaticArray... ) where {S}
67
+ @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{()} ,
68
+ :: Size{S} , a:: StaticArray... ) where {S}
72
69
tmp = [:(a[$ j][1 ]) for j ∈ 1 : length (a)]
73
70
expr = :(f ($ (tmp... )))
74
71
for i ∈ 2 : prod (S)
80
77
@inbounds return $ expr
81
78
end
82
79
end
83
-
84
- @generated function _mapreduce (f, op, v0, :: Size{S} , a:: StaticArray... ) where {S}
85
- expr = :v0
80
+
81
+ @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{(:init,)} ,
82
+ :: Size{S} , a:: StaticArray... ) where {S}
83
+ expr = :(nt. init)
86
84
for i ∈ 1 : prod (S)
87
85
tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
88
86
expr = :(op ($ expr, f ($ (tmp... ))))
93
91
end
94
92
end
95
93
96
- # #################
97
- # # mapreducedim ##
98
- # #################
99
-
100
- # I'm not sure why the signature for this from Base precludes multiple arrays?
101
- # (also, why not mutating `mapreducedim!` and `reducedim!`?)
102
- # (similarly, `broadcastreduce` and `broadcastreducedim` sounds useful)
103
- @inline function mapreducedim (f, op, a:: StaticArray , :: Type{Val{D}} ) where {D}
104
- _mapreducedim (f, op, Size (a), a, Val{D})
105
- end
106
-
107
- @inline function mapreducedim (f, op, a:: StaticArray , :: Type{Val{D}} , v0) where {D}
108
- _mapreducedim (f, op, Size (a), a, Val{D}, v0)
94
+ @inline function _mapreduce (f, op, :: Type{Val{D}} , nt:: NamedTuple ,sz:: Size{S} , a:: StaticArray ) where {S,D}
95
+ Base. depwarn (" `Val{D}` as dims argument is deprecated, use `D` or `Val(D)` instead." )
96
+ _mapreduce (f, op, Val (D), nt, sz, a)
109
97
end
98
+ @inline _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S} =
99
+ _mapreduce (f, op, Val (D), nt, sz, a)
110
100
111
- @generated function _mapreducedim (f, op, :: Size{S} , a:: StaticArray , :: Type{Val{D}} ) where {S,D}
101
+
102
+ @generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{()} ,
103
+ :: Size{S} , a:: StaticArray ) where {S,D}
112
104
N = length (S)
113
105
Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
114
106
T0 = eltype (a)
@@ -133,14 +125,15 @@ end
133
125
end
134
126
end
135
127
136
- @generated function _mapreducedim (f, op, :: Size{S} , a:: StaticArray , :: Type{Val{D}} , v0:: T ) where {S,D,T}
128
+ @generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{(:init,),Tuple{T}} ,
129
+ :: Size{S} , a:: StaticArray ) where {S,D,T}
137
130
N = length (S)
138
131
Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
139
132
140
133
exprs = Array {Expr} (undef, Snew)
141
134
itr = [1 : n for n = Snew]
142
135
for i ∈ Base. product (itr... )
143
- expr = :v0
136
+ expr = :(nt . init)
144
137
for k = 1 : S[D]
145
138
ik = collect (i)
146
139
ik[D] = k
160
153
# # reduce ##
161
154
# ###########
162
155
163
- @inline reduce (op, a:: StaticArray ) = mapreduce (identity, op, a)
164
- @inline reduce (op, v0, a:: StaticArray ) = mapreduce (identity, op, v0, a)
165
-
166
- # ##############
167
- # # reducedim ##
168
- # ##############
169
-
170
- @inline reducedim (op, a:: StaticArray , :: Type{Val{D}} ) where {D} = mapreducedim (identity, op, a, Val{D})
171
- @inline reducedim (op, a:: StaticArray , :: Type{Val{D}} , v0) where {D} = mapreducedim (identity, op, a, Val{D}, v0)
156
+ @inline reduce (op, a:: StaticArray ; kw... ) = mapreduce (identity, op, a; kw... )
172
157
173
158
# ######################
174
159
# # related functions ##
@@ -178,68 +163,64 @@ end
178
163
#
179
164
# Implementation notes:
180
165
#
181
- # 1. When providing an initial value v0, note that its location is different in reduce and
182
- # reducedim: v0 comes earlier than collection in reduce, whereas it is the last argument in
183
- # reducedim. The same difference exists between mapreduce and mapreducedim.
184
- #
185
- # 2. mapreduce and mapreducedim usually do not take initial value v0, because we don't
166
+ # 1. mapreduce and mapreducedim usually do not take initial value, because we don't
186
167
# always know the return type of an arbitrary mapping function f. (We usually want to use
187
- # some initial value such as one(T) or zero(T) as v0 , where T is the return type of f, but
168
+ # some initial value such as one(T) or zero(T), where T is the return type of f, but
188
169
# if users provide type-unstable f, its return type cannot be known.) Therefore, mapped
189
170
# versions of the functions implemented below usually require the collection to have at
190
171
# least two entries.
191
172
#
192
- # 3 . Exceptions are the ones that require Boolean mapping functions. For example, f in
193
- # all and any must return Bool, so we know the appropriate v0 is true and false,
173
+ # 2 . Exceptions are the ones that require Boolean mapping functions. For example, f in
174
+ # all and any must return Bool, so we know the appropriate initial value is true and false,
194
175
# respectively. Therefore, all(f, ...) and any(f, ...) are implemented by mapreduce(f, ...)
195
176
# with an initial value v0 = true and false.
196
- @inline iszero (a:: StaticArray{<:Any,T} ) where {T} = reduce ((x,y) -> x && (y== zero (T)), true , a)
197
-
198
- @inline sum (a:: StaticArray{<:Any,T} ) where {T} = reduce (+ , zero (T), a)
199
- @inline sum (f:: Function , a:: StaticArray ) = mapreduce (f, + , a)
200
- @inline sum (a:: StaticArray{<:Any,T} , :: Type{Val{D}} ) where {T,D} = reducedim (+ , a, Val{D}, zero (T))
201
- @inline sum (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where D = mapreducedim (f, + , a, Val{D})
202
-
203
- @inline prod (a:: StaticArray{<:Any,T} ) where {T} = reduce (* , one (T), a)
204
- @inline prod (f:: Function , a:: StaticArray{<:Any,T} ) where {T} = mapreduce (f, * , a)
205
- @inline prod (a:: StaticArray{<:Any,T} , :: Type{Val{D}} ) where {T,D} = reducedim (* , a, Val{D}, one (T))
206
- @inline prod (f:: Function , a:: StaticArray{<:Any,T} , :: Type{Val{D}} ) where {T,D} = mapreducedim (f, * , a, Val{D})
207
-
208
- @inline count (a:: StaticArray{<:Any,Bool} ) = reduce (+ , 0 , a)
209
- @inline count (f:: Function , a:: StaticArray ) = mapreduce (x-> f (x):: Bool , + , 0 , a)
210
- @inline count (a:: StaticArray{<:Any,Bool} , :: Type{Val{D}} ) where {D} = reducedim (+ , a, Val{D}, 0 )
211
- @inline count (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where {D} = mapreducedim (x-> f (x):: Bool , + , a, Val{D}, 0 )
212
-
213
- @inline all (a:: StaticArray{<:Any,Bool} ) = reduce (& , true , a) # non-branching versions
214
- @inline all (f:: Function , a:: StaticArray ) = mapreduce (x-> f (x):: Bool , & , true , a)
215
- @inline all (a:: StaticArray{<:Any,Bool} , :: Type{Val{D}} ) where {D} = reducedim (& , a, Val{D}, true )
216
- @inline all (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where {D} = mapreducedim (x-> f (x):: Bool , & , a, Val{D}, true )
217
-
218
- @inline any (a:: StaticArray{<:Any,Bool} ) = reduce (| , false , a) # (benchmarking needed)
219
- @inline any (f:: Function , a:: StaticArray ) = mapreduce (x-> f (x):: Bool , | , false , a) # (benchmarking needed)
220
- @inline any (a:: StaticArray{<:Any,Bool} , :: Type{Val{D}} ) where {D} = reducedim (| , a, Val{D}, false )
221
- @inline any (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where {D} = mapreducedim (x-> f (x):: Bool , | , a, Val{D}, false )
222
-
223
- @inline mean (a:: StaticArray ) = sum (a) / length (a)
224
- @inline mean (f:: Function , a:: StaticArray ) = sum (f, a) / length (a)
225
- @inline mean (a:: StaticArray , :: Type{Val{D}} ) where {D} = sum (a, Val{D}) / size (a, D)
226
- @inline mean (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where {D} = sum (f, a, Val{D}) / size (a, D)
227
-
228
- @inline minimum (a:: StaticArray ) = reduce (min, a) # base has mapreduce(idenity, scalarmin, a)
229
- @inline minimum (f:: Function , a:: StaticArray ) = mapreduce (f, min, a)
230
- @inline minimum (a:: StaticArray , :: Type{Val{D}} ) where {D} = reducedim (min, a, Val{D})
231
- @inline minimum (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where {D} = mapreducedim (f, min, a, Val{D})
232
-
233
- @inline maximum (a:: StaticArray ) = reduce (max, a) # base has mapreduce(idenity, scalarmax, a)
234
- @inline maximum (f:: Function , a:: StaticArray ) = mapreduce (f, max, a)
235
- @inline maximum (a:: StaticArray , :: Type{Val{D}} ) where {D} = reducedim (max, a, Val{D})
236
- @inline maximum (f:: Function , a:: StaticArray , :: Type{Val{D}} ) where {D} = mapreducedim (f, max, a, Val{D})
177
+ #
178
+ # TODO : change to use Base.reduce_empty/Base.reduce_first
179
+ @inline iszero (a:: StaticArray{<:Any,T} ) where {T} = reduce ((x,y) -> x && iszero (y), a, init= true )
180
+
181
+ @inline sum (a:: StaticArray{<:Any,T} ; dims= :) where {T} = reduce (+ , a; dims= dims)
182
+ @inline sum (f, a:: StaticArray{<:Any,T} ; dims= :) where {T} = mapreduce (f, + , a; dims= dims)
183
+ @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Any,T} ; dims= :) where {T} = mapreduce (f, + , a; dims= dims) # avoid ambiguity
184
+
185
+ @inline prod (a:: StaticArray{<:Any,T} ; dims= :) where {T} = reduce (* , a; dims= dims)
186
+ @inline prod (f, a:: StaticArray{<:Any,T} ; dims= :) where {T} = mapreduce (f, * , a; dims= dims)
187
+ @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Any,T} ; dims= :) where {T} = mapreduce (f, * , a; dims= dims)
188
+
189
+ @inline count (a:: StaticArray{<:Any,Bool} ; dims= :) = reduce (+ , a; dims= dims)
190
+ @inline count (f, a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , + , a; dims= dims)
191
+
192
+ @inline all (a:: StaticArray{<:Any,Bool} ; dims= :) = reduce (& , a; dims= dims, init= true ) # non-branching versions
193
+ @inline all (f:: Function , a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , & , a; dims= dims, init= true )
194
+
195
+ @inline any (a:: StaticArray{<:Any,Bool} ; dims= :) = reduce (| , a; dims= dims, init= false ) # (benchmarking needed)
196
+ @inline any (f:: Function , a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , | , a; dims= dims, init= false ) # (benchmarking needed)
197
+
198
+ _mean_denom (a, dims:: Colon ) = length (a)
199
+ _mean_denom (a, dims:: Int ) = size (a, dims)
200
+ _mean_denom (a, :: Val{D} ) where {D} = size (a, D)
201
+ _mean_denom (a, :: Type{Val{D}} ) where {D} = size (a, D)
202
+
203
+ @inline mean (a:: StaticArray ; dims= :) = sum (a; dims= dims) / _mean_denom (a,dims)
204
+ @inline mean (f:: Function , a:: StaticArray ;dims= :) = sum (f, a; dims= dims) / _mean_denom (a,dims)
205
+
206
+ @inline minimum (a:: StaticArray ; dims= :) = reduce (min, a; dims= dims) # base has mapreduce(idenity, scalarmin, a)
207
+ @inline minimum (f:: Function , a:: StaticArray ; dims= :) = mapreduce (f, min, a; dims= dims)
208
+
209
+ @inline maximum (a:: StaticArray ; dims= :) = reduce (max, a; dims= dims) # base has mapreduce(idenity, scalarmax, a)
210
+ @inline maximum (f:: Function , a:: StaticArray ; dims= :) = mapreduce (f, max, a; dims= dims)
237
211
238
212
# Diff is slightly different
239
- @inline diff (a:: StaticArray ) = diff (a, Val{ 1 } )
240
- @inline diff (a:: StaticArray , :: Type{Val{D}} ) where {D} = _diff ( Size (a), a, Val{D} )
213
+ @inline diff (a:: StaticArray ; dims ) = _diff ( Size (a), a, dims )
214
+ @inline diff (a:: StaticVector ) = diff (a;dims = Val ( 1 ) )
241
215
242
- @generated function _diff (:: Size{S} , a:: StaticArray , :: Type{Val{D}} ) where {S,D}
216
+ @inline function _diff (sz:: Size{S} , a:: StaticArray , D:: Int ) where {S}
217
+ _diff (sz,a,Val (D))
218
+ end
219
+ @inline function _diff (sz:: Size{S} , a:: StaticArray , :: Type{Val{D}} ) where {S,D}
220
+ Base. depwarn (" `Val{D}` as dims argument is deprecated, use `D` or `Val(D)` instead." )
221
+ _diff (sz,a,Val (D))
222
+ end
223
+ @generated function _diff (:: Size{S} , a:: StaticArray , :: Val{D} ) where {S,D}
243
224
N = length (S)
244
225
Snew = ([n== D ? S[n]- 1 : S[n] for n = 1 : N]. .. ,)
245
226
0 commit comments