diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 5588c245c..a70bc27bf 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -158,7 +158,7 @@ for mimum in (:minimum, :maximum) imax = findfirst(mask) else y = $mimum(f, xs; dims=dims) - mask = y .== f.(xs) + mask = y .== f.(xs) # this is N^2 more calls to f, that's a lot! mask .= (mask .== cumsum(mask; dims=dims) .== true) imax = findall(mask) end @@ -174,7 +174,7 @@ for mimum in (:minimum, :maximum) elseif Base.issingletontype(F) # Then we need not accumulate the gradient with respect to `f`. dfs = NoTangent() - # On a matrix we called `f` N^2 times, now call it N more with `rrule_via_ad`: + # On a matrix we called `f` 2*N^2 times, now call it N more with `rrule_via_ad`: dxmax = map(view(xs, imax), unthunk(dys)) do x, dy _, bk = rrule_via_ad(config, f, x) df, dx = bk(dy) @@ -184,7 +184,7 @@ for mimum in (:minimum, :maximum) # This could perhaps accumulate df more smartly... call(g, x) = g(x) backs = map(x -> last(rrule_via_ad(config, f, x)), view(xs, imax)) - dfs_and_dxs = map(unthunk∘last∘call, backs, unthunk(dy)) + dfs_and_dxs = map(unthunk∘last∘call, backs, unthunk(dys)) dfs = sum(first, dfs_and_dxs) dxmax = map(unthunk∘last, dfs_and_dxs) end