@@ -142,89 +142,57 @@ end
142
142
# ####
143
143
144
144
for mimum in (:minimum , :maximum )
145
- mimum_pullback = Symbol (mimum, :_pullback_f )
145
+ pullback1 = Symbol (mimum, :_pullback_f )
146
+ pullback2 = Symbol (mimum, :_pullback_composed )
146
147
findm = Symbol (:find , string (mimum)[1 : 3 ])
147
148
148
149
@eval function rrule (
149
150
config:: RuleConfig{>:HasReverseMode} , :: typeof ($ mimum), f:: F , xs:: AbstractArray{<:Number} ; dims= :
150
151
) where {F}
151
- if dims isa Colon && VERSION >= v " 1.7-"
152
- # Best case, we can use findmax to get index:
153
- y, imax = $ findm (f, xs)
154
- elseif dims isa Colon
155
- # Explicitly figure out where it attains the max:
156
- y = $ mimum (f, xs; dims= dims)
157
- mask = y .== f .(xs)
158
- imax = findfirst (mask)
159
- else
160
- y = $ mimum (f, xs; dims= dims)
161
- mask = y .== f .(xs) # this is N^2 more calls to f, that's a lot!
162
- mask .= (mask .== cumsum (mask; dims= dims) .== true )
163
- imax = findall (mask)
164
- end
165
152
project = ProjectTo (xs)
166
153
167
- function $mimum_pullback (dys)
168
- if dims isa Colon
169
- # Notice that this does evaluate `f` one more time, but will this matter
170
- # unless `f` is sateful? In which case both this and `maximum(f.(xs))` give undefined results.
171
- _, back = rrule_via_ad (config, f, xs[imax])
172
- dfs, _dxmax = back (unthunk (dys))
173
- dxmax = unthunk (_dxmax)
174
- elseif Base. issingletontype (F)
175
- # Then we need not accumulate the gradient with respect to `f`.
176
- dfs = NoTangent ()
177
- # On a matrix we called `f` 2*N^2 times, now call it N more with `rrule_via_ad`:
178
- dxmax = map (view (xs, imax), unthunk (dys)) do x, dy
179
- _, bk = rrule_via_ad (config, f, x)
180
- df, dx = bk (dy)
181
- unthunk (dx)
154
+ # The easy case is when we can use `findmax` to get index, and write into it:
155
+ if dims isa Colon && VERSION >= v " 1.7-"
156
+ y, ind = $ findm (f, xs)
157
+ function $pullback1 (dy)
158
+ # Notice this evaluates `f` one more time, but this shouldn't matter
159
+ # unless `f` is sateful, in which case both this and `maximum(f.(xs))`
160
+ # give undefined results.
161
+ _, one_back = rrule_via_ad (config, f, xs[ind])
162
+ df, one_dx_raw = one_back (unthunk (dy))
163
+ one_dx = unthunk (one_dx_raw)
164
+ x_thunk = @thunk project (_writezero (xs, one_dx, ind, dims))
165
+ x_ithunk = InplaceableThunk (x_thunk) do dxs
166
+ view (dxs, ind) .+ = one_dx
167
+ dxs
182
168
end
183
- else
184
- # This could perhaps accumulate df more smartly...
185
- call (g, x) = g (x)
186
- backs = map (x -> last (rrule_via_ad (config, f, x)), view (xs, imax))
187
- dfs_and_dxs = map (unthunk∘ last∘ call, backs, unthunk (dys))
188
- dfs = sum (first, dfs_and_dxs)
189
- dxmax = map (unthunk∘ last, dfs_and_dxs)
169
+ return (NoTangent (), df, x_ithunk)
190
170
end
191
- x_thunk = @thunk begin
192
- dxs = fill! (similar (xs, eltype (dxmax)), false )
193
- view (dxs, imax) .= dxmax
194
- project (dxs)
195
- end
196
- x_ithunk = InplaceableThunk (x_thunk) do dxs
197
- view (dxs, imax) .= view (dxs, imax) .+ dxmax
198
- dxs
171
+ return y, $ pullback1
172
+
173
+ # Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
174
+ else
175
+ mid, cast_back = rrule_via_ad (config, broadcast, f, xs; dims= dims)
176
+ y, max_back = rrule ($ mimum, fxs; dims= dims)
177
+ function $pullback2 (dys)
178
+ _, dmid = max_back (dys)
179
+ _, df, dxs = cast_back (dmid) # if cast_back from rrule_via_ad makes an InplaceableThunk,
180
+ return (NoTangent (), df, project (dxs)) # then this project() will give an error.
199
181
end
200
- return NoTangent (), dfs, x_ithunk
182
+ return y, $ pullback2
201
183
end
202
184
203
- return y, $ mimum_pullback
204
- end
205
-
185
+ end # @eval function rrule(...)
206
186
end
207
187
208
- #=
209
-
210
- julia> @btime gradient(x -> maximum(sqrt, x), $(rand(30,30)));
211
- 5.632 μs (51 allocations: 8.39 KiB)
212
-
213
- julia> @btime gradient(x -> sum(maximum(sqrt, x, dims=1)), $(rand(30,30)));
214
- 9.792 μs (34 allocations: 13.92 KiB)
215
-
216
- julia> @btime gradient(x -> maximum(sqrt.(x)), $(rand(30,30)));
217
- 4.321 μs (16 allocations: 35.97 KiB)
218
-
219
- # bigger, nastier
220
-
221
- julia> @btime gradient(x -> maximum(log∘exp, x), $(rand(300,300)));
222
- 1.714 ms (132 allocations: 706.33 KiB)
223
-
224
- julia> @btime gradient(x -> maximum((log∘exp).(x)), $(rand(300,300)));
225
- 1.595 ms (20 allocations: 3.43 MiB)
226
-
227
- =#
188
+ # from another PR:
189
+ function _writezero (x, dy, ind, dims)
190
+ # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
191
+ # allow `eltype(dy)`, nor does it work for many structured matrices.
192
+ dx = fill! (similar (x, eltype (dy), axes (x)), false )
193
+ view (dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
194
+ dx
195
+ end
228
196
229
197
# ####
230
198
# #### `prod`
0 commit comments