Skip to content

Commit 59120a0

Browse files
committed
tweaks
1 parent 765dbda commit 59120a0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ for mimum in (:minimum, :maximum)
158158
imax = findfirst(mask)
159159
else
160160
y = $mimum(f, xs; dims=dims)
161-
mask = y .== f.(xs)
161+
mask = y .== f.(xs) # this is N^2 more calls to f, that's a lot!
162162
mask .= (mask .== cumsum(mask; dims=dims) .== true)
163163
imax = findall(mask)
164164
end
@@ -174,7 +174,7 @@ for mimum in (:minimum, :maximum)
174174
elseif Base.issingletontype(F)
175175
# Then we need not accumulate the gradient with respect to `f`.
176176
dfs = NoTangent()
177-
# On a matrix we called `f` N^2 times, now call it N more with `rrule_via_ad`:
177+
# On a matrix we called `f` 2*N^2 times, now call it N more with `rrule_via_ad`:
178178
dxmax = map(view(xs, imax), unthunk(dys)) do x, dy
179179
_, bk = rrule_via_ad(config, f, x)
180180
df, dx = bk(dy)
@@ -184,7 +184,7 @@ for mimum in (:minimum, :maximum)
184184
# This could perhaps accumulate df more smartly...
185185
call(g, x) = g(x)
186186
backs = map(x -> last(rrule_via_ad(config, f, x)), view(xs, imax))
187-
dfs_and_dxs = map(unthunklastcall, backs, unthunk(dy))
187+
dfs_and_dxs = map(unthunklastcall, backs, unthunk(dys))
188188
dfs = sum(first, dfs_and_dxs)
189189
dxmax = map(unthunklast, dfs_and_dxs)
190190
end

0 commit comments

Comments
 (0)