@@ -158,7 +158,7 @@ for mimum in (:minimum, :maximum)
158
158
imax = findfirst (mask)
159
159
else
160
160
y = $ mimum (f, xs; dims= dims)
161
- mask = y .== f .(xs)
161
+ mask = y .== f .(xs) # this is N^2 more calls to f, that's a lot!
162
162
mask .= (mask .== cumsum (mask; dims= dims) .== true )
163
163
imax = findall (mask)
164
164
end
@@ -174,7 +174,7 @@ for mimum in (:minimum, :maximum)
174
174
elseif Base. issingletontype (F)
175
175
# Then we need not accumulate the gradient with respect to `f`.
176
176
dfs = NoTangent ()
177
- # On a matrix we called `f` N^2 times, now call it N more with `rrule_via_ad`:
177
+ # On a matrix we called `f` 2* N^2 times, now call it N more with `rrule_via_ad`:
178
178
dxmax = map (view (xs, imax), unthunk (dys)) do x, dy
179
179
_, bk = rrule_via_ad (config, f, x)
180
180
df, dx = bk (dy)
@@ -184,7 +184,7 @@ for mimum in (:minimum, :maximum)
184
184
# This could perhaps accumulate df more smartly...
185
185
call (g, x) = g (x)
186
186
backs = map (x -> last (rrule_via_ad (config, f, x)), view (xs, imax))
187
- dfs_and_dxs = map (unthunk∘ last∘ call, backs, unthunk (dy ))
187
+ dfs_and_dxs = map (unthunk∘ last∘ call, backs, unthunk (dys ))
188
188
dfs = sum (first, dfs_and_dxs)
189
189
dxmax = map (unthunk∘ last, dfs_and_dxs)
190
190
end
0 commit comments