Skip to content

Commit

Permalink
tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 5, 2021
1 parent 765dbda commit 59120a0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ for mimum in (:minimum, :maximum)
imax = findfirst(mask)
else
y = $mimum(f, xs; dims=dims)
mask = y .== f.(xs)
mask = y .== f.(xs) # this is N^2 more calls to f, that's a lot!
mask .= (mask .== cumsum(mask; dims=dims) .== true)
imax = findall(mask)
end
Expand All @@ -174,7 +174,7 @@ for mimum in (:minimum, :maximum)
elseif Base.issingletontype(F)
# Then we need not accumulate the gradient with respect to `f`.
dfs = NoTangent()
# On a matrix we called `f` N^2 times, now call it N more with `rrule_via_ad`:
# On a matrix we called `f` 2*N^2 times, now call it N more with `rrule_via_ad`:
dxmax = map(view(xs, imax), unthunk(dys)) do x, dy
_, bk = rrule_via_ad(config, f, x)
df, dx = bk(dy)
Expand All @@ -184,7 +184,7 @@ for mimum in (:minimum, :maximum)
# This could perhaps accumulate df more smartly...
call(g, x) = g(x)
backs = map(x -> last(rrule_via_ad(config, f, x)), view(xs, imax))
dfs_and_dxs = map(unthunklastcall, backs, unthunk(dy))
dfs_and_dxs = map(unthunklastcall, backs, unthunk(dys))
dfs = sum(first, dfs_and_dxs)
dxmax = map(unthunklast, dfs_and_dxs)
end
Expand Down

0 comments on commit 59120a0

Please sign in to comment.