Configured rule for maximum(f, xs)
#490
Conversation
This has been much simplified. For the case of a complete reduction only:

```julia
julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30))); # this PR + Zygote + Julia 1.8
  min 8.625 μs, mean 10.906 μs (52 allocations, 8.92 KiB. GC mean 13.94%)

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
  min 10.041 μs, mean 16.087 μs (49 allocations, 36.88 KiB. GC mean 20.75%)
```

With a more expensive function:

```julia
julia> @btime gradient(x -> sum(maximum(log∘exp, x)), $(rand(30,30)));
  min 20.208 μs, mean 22.335 μs (116 allocations, 10.88 KiB. GC mean 5.22%)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x))), $(rand(30,30)));
  min 19.291 μs, mean 25.757 μs (49 allocations, 36.88 KiB. GC mean 13.03%)

julia> @btime maximum(log∘exp, $(rand(30,30)));
  min 8.958 μs, mean 9.128 μs (0 allocations)
```

The broadcasted one uses dual numbers, which is much quicker. Note BTW that there is no chunk mode in play here -- it always evaluates […]. I'm not so sure why the complete reduction is slower than broadcasting here, but it's much closer, and uses 3x less memory. Diffractor, BTW, does not see this rule. It does see #480, but broadcast times are variable.
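The dual-number mechanism behind the broadcast path can be pictured with a toy sketch. `MyDual` here is my own illustrative type, not ForwardDiff's `Dual`: each element carries its derivative alongside its value, which is why the broadcast path stores a full extra array of derivatives.

```julia
# Toy forward-mode dual number, for illustration only.
struct MyDual
    val::Float64   # primal value
    der::Float64   # derivative carried alongside
end
Base.sqrt(d::MyDual) = MyDual(sqrt(d.val), d.der / (2 * sqrt(d.val)))

xs = [1.0, 4.0, 9.0]
duals = [sqrt(MyDual(x, 1.0)) for x in xs]   # one dual per element, N derivatives stored
vals = [d.val for d in duals]                # [1.0, 2.0, 3.0]
ders = [d.der for d in duals]                # [0.5, 0.25, 1/6]
```

The complete-reduction rule avoids storing those N per-element derivatives, which is where the roughly 3x memory saving above comes from.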
Status here is as in (edited) first message above. Perhaps the broadcast path can be easily tested using JuliaDiff/ChainRulesTestUtils.jl#243 once that's available.
A few questions, generally looks good. Do you plan to extend the tests?
test/rulesets/Base/mapreduce.jl
Outdated
```julia
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)
```
these will need JuliaDiff/FiniteDifferences.jl#203
I thought these needed JuliaDiff/ChainRulesTestUtils.jl#243: with `dims` it always calls broadcast.
Yep, they do need JuliaDiff/ChainRulesTestUtils.jl#243 (now merged), but also JuliaDiff/FiniteDifferences.jl#203 to get around `to_vec`ing `InplaceableThunk`s correctly (tested locally)
But where do InplaceableThunks come from? This path of this rule doesn't make them.
I do still get an error with only the CRTU update:
```julia
julia> test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
test_rrule: maximum on typeof(sqrt),Matrix{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/fCvaU/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch("second dimension of A, 4, does not match length of x, 7")
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
[2] mul!
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
[3] mul!
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[4] *(tA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/R6uao/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#45"{ChainRulesTestUtils.var"#call#41"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{typeof(broadcast), typeof(sqrt), Matrix{Float64}}, Tuple{Bool, Bool, Bool}}, ȳ::InplaceableThunk{Thunk{ChainRules.var"#1316#1319"{Matrix{Float64}, Int64, Matrix{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, Matrix{CartesianIndex{2}}}}, ChainRules.var"#1317#1320"{Matrix{Float64}, Int64, Matrix{CartesianIndex{2}}}}, x::Matrix{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/R6uao/src/grad.jl:73
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/fCvaU/src/finite_difference_calls.jl:51
[8] f_pb
@ ~/.julia/packages/ChainRulesTestUtils/fCvaU/src/rule_config.jl:40 [inlined]
[9] (::ChainRules.var"#minormax_f_back2#2098"{ChainRules.var"#maximum_pullback#1326"{ChainRules.var"#findmax_pullback#1318"{Int64
```
Solved by the to_vec PR, as you said.
Can this thing give less cryptic errors than this "DimensionMismatch" when it goes wrong?
Yeah, I agree with you in general: JuliaDiff/ChainRulesTestUtils.jl#244
Here though this is coming from `rrule_via_ad` using `_make_j′vp_call` rather than the usual place 😂
Solving JuliaDiff/ChainRulesTestUtils.jl#213 would be a big QoL improvement indeed. It's on my list
JuliaDiff/FiniteDifferences.jl#203 is now merged, so I think we can update the tests
Great!
This one is weird locally, but on 1.6 it seems to work (or will, once changed to `≈ [10 0 0; 0 -20 0]`):
```julia
julia> y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2);

julia> @test y2 == hcat([1, 4])
Test Passed
  Expression: y2 == hcat([1, 4])
   Evaluated: [1; 4;;] == [1; 4;;]

julia> bk2(hcat([10, 20]))
(NoTangent(), NoTangent(), NoTangent())
```
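For reference, the gradient that example should produce can be worked out by hand: scatter each row's cotangent onto its arg-min position, scaled by the derivative of `abs` there. This is just my illustration of the expected behaviour, not the rule's implementation:

```julia
x = [1 2 3; -5 -4 -4]
dy = [10, 20]
_, inds = findmin(abs.(x); dims = 2)   # CartesianIndex of each row's minimum
dx = zeros(size(x))
for (k, i) in enumerate(inds)
    dx[i] = dy[k] * sign(x[i])         # d/dx abs(x) == sign(x) away from zero
end
dx                                     # == [10 0 0; 0 -20 0]
```

This matches the `≈ [10 0 0; 0 -20 0]` expected above: the minimum of row 2 is `abs(-4)` at column 2, so its cotangent 20 is flipped in sign.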
- save less stuff in sum(f, xs) rule
- probably destroyed in the rebase
- re-organise
- change to use BitArray
- add a few tests
- Revert "save less stuff in sum(f, xs) rule" (This reverts commit c8034da.)
- tidy, add cumsum trick
- tests for multiple maxima
- tweaks
- fixup
- update, tidy
- Apply 3 suggestions (Co-authored-by: Miha Zgubic <[email protected]>)
- add an error
- remove error, as closing over `y` breaks inference
- simplify, update
- solve Core.Box
- tests approx
This uses the `RuleConfig{>:HasReverseMode}` story to call back into AD to write a rule for `maximum(f, xs)`.

It's much simplified from the first attempt: it calls `i = findmax(f, xs)`, and then uses `rrule_via_ad(f, xs[i])`. However, it only needs one such call, rather than one for every element. That means it ends up calling `f` say `N^2 + 1` times for a matrix (or `N^2 + N` with `dims`). This is much more efficient than calling it via AD all `N^2` times, saving the pullbacks somewhere, and calling just one. Not always faster than Zygote's current broadcasting (which uses ForwardDiff), but much less memory.

Fast case, before & after: before this PR, `gradient(x -> sum(maximum(sqrt, x, dims=1)), (rand(30,30)))` gives an error with Zygote. After, it is the same speed as broadcasting. What doesn't seem easy now is testing the broadcast path.

If this is OK, then perhaps the `sum(f, x)` rule from #441 should also consider calling `f` more times. There's a commit here doing that, which cuts the memory use by quite a bit. Perhaps there are functions `f` for which calling twice would be slower? Perhaps writing `sum(f, x)` vs. `sum(f.(x))` is how you emphasise that you care more about memory? (It may make sense to remove this & discuss `sum` in another thread.) [Now removed here.]

All WIP, needs more careful testing, etc.
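The strategy can be sketched without any AD machinery. `maximum_with_pullback` and `df` are my own names here: the real rule obtains the derivative at the arg-max element via `rrule_via_ad` rather than from a hand-supplied `df`.

```julia
# Toy version of the complete-reduction strategy: N calls to f to locate
# the maximum, then a single derivative evaluation at that element.
function maximum_with_pullback(f, df, xs)
    y, i = findmax(map(f, xs))       # like findmax(f, xs) on Julia ≥ 1.7
    function pullback(dy)
        dxs = zeros(size(xs))
        dxs[i] = df(xs[i]) * dy      # only the arg-max element gets a gradient
        return dxs
    end
    return y, pullback
end

y, back = maximum_with_pullback(sqrt, x -> 1 / (2 * sqrt(x)), [1.0 4.0; 9.0 16.0])
# y == 4.0; back(1.0) is zero except 1/(2*sqrt(16)) == 0.125 at the 16.0 entry
```

This is why the memory cost stays flat: nothing per-element is saved for the pullback besides the single index `i`.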