From 9238c2e79b972c77a9e991980ea69d70b9de1cc8 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Feb 2024 20:31:05 +0800 Subject: [PATCH] Support frules with keyword arguments --- src/stage1/forward.jl | 28 +++++++++++++++++++++++++--- test/forward.jl | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 7b8a37b7..c059e931 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -117,11 +117,33 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end -_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...) -function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...) +# frules with keywords support: +function (::∂☆internal{1})(kwc::ATB{1, typeof(Core.kwcall)}, kw::ATB{1}, f::ATB{1}, args::ATB{1}...) + args_primals = (primal(f), map(primal, args)...) + args_partials = (first_partial(f), map(first_partial, args)...) + # First check if directly overloading frule for kwcall + r = _frule( + (first_partial(kwc), first_partial(kw), args_partials...), + primal(kwc), primal(kw), args_primals... + ) + if r===nothing + # then check if the frule for f accepts keywords + # This silently discards tangents of the kw-args + # TODO: should we error if they nonzero? + r = _frule(args_partials, args_primals...; primal(kw)...) + end + if r === nothing + return ∂☆recurse{1}()(kwc, kw, f, args...) + else + return shuffle_base(r) + end +end + +_frule(partials, primals...; kwargs...) = frule(DiffractorRuleConfig(), partials, primals...; kwargs...) +function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...; kwargs...) # frules are linear in partials, so zero maps to zero, no need to evaluate the frule # If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either - r = f(primal_args...) + r = f(primal_args...; kwargs...) return r, zero_tangent(r) end diff --git a/test/forward.jl b/test/forward.jl index f8040639..ec3c018c 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -227,4 +227,42 @@ end end end +@testset "frule with kwarg" begin + mulby_kw(v; n) = n*v + triple(v) = mulby_kw(v; n=3) + frule_hits = 0 + function ChainRulesCore.frule((_, dv), ::typeof(mulby_kw), v; n) + y = mulby_kw(v; n) + dy = n*dv + frule_hits +=1 + return y, dy + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @assert frule_hits == 0 + @test triple'(2.0) == 3.0 + @test frule_hits == 1 + end + + mulby_kw2(v; n) = n*v + square(v) = mulby_kw2(v; n=v) + frule_hits = 0 + function ChainRulesCore.frule((_, dkw, _, dv), ::typeof(Core.kwcall), kw, ::typeof(mulby_kw2), v) + n = kw.n + dn = dkw.n + y = mulby_kw2(v; n) + dy = n*dv + dn*v + frule_hits +=1 + return y, dy + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @assert frule_hits == 0 + @test square'(3.0) == 6.0 + @test frule_hits == 1 + end +end + end # module + +