From 2c3df716b0c0d76c3b94cf00edac93571d07000b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 26 Dec 2022 15:35:07 +0100 Subject: [PATCH] Clarify issue with scalar products --- src/chain_rules.jl | 16 ++++++++++++--- test/chain_rules.jl | 47 +++++++++++++++++++++++++++------------------ 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/chain_rules.jl b/src/chain_rules.jl index 60877411..6e9d0ba8 100644 --- a/src/chain_rules.jl +++ b/src/chain_rules.jl @@ -1,3 +1,9 @@ +# The pullback depends on the scalar product on the polynomials +# With the scalar product `LinearAlgebra.dot(p, q) = p * q`, there is no pullback for `differentiate` +# With the scalar product `_dot(p, q)` of `test/chain_rules.jl`, there is a pullback for `differentiate` +# and the pullback for `*` changes. +# We give the one for the scalar product `_dot`. + import ChainRulesCore ChainRulesCore.@scalar_rule +(x::APL) true @@ -22,7 +28,7 @@ function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL) return p * q, MA.add_mul!!(p * Δq, q, Δp) end -function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function} +function _mult_pullback(op::F, ts, p, Δ) where {F<:Function} for t in terms(p) c = coefficient(t) m = monomial(t) @@ -38,20 +44,23 @@ function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function} end function adjoint_mult_left(p, Δ) ts = MA.promote_operation(*, MA.promote_operation(adjoint, termtype(p)), termtype(Δ))[] - return _adjoint_mult(ts, p, Δ) do c, d + return _mult_pullback(ts, p, Δ) do c, d c' * d end end function adjoint_mult_right(p, Δ) ts = MA.promote_operation(*, termtype(Δ), MA.promote_operation(adjoint, termtype(p)))[] - return _adjoint_mult(ts, p, Δ) do c, d + return _mult_pullback(ts, p, Δ) do c, d d * c' end end function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL) function times_pullback2(ΔΩ̇) + # This is for the scalar product `_dot`: return (ChainRulesCore.NoTangent(), adjoint_mult_right(q, ΔΩ̇), adjoint_mult_left(p, 
ΔΩ̇)) + # For the scalar product `dot`, it would be instead: + return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇) end return p * q, times_pullback2 end @@ -82,6 +91,7 @@ end function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x) return differentiate(p, x), differentiate(Δp, x) end +# This is for the scalar product `_dot`, there is no pullback for the scalar product `dot` function differentiate_pullback(Δdpdx, x) return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent() end diff --git a/test/chain_rules.jl b/test/chain_rules.jl index d8e3cb71..307893f0 100644 --- a/test/chain_rules.jl +++ b/test/chain_rules.jl @@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout) @test dot(Δin, rΔin[2:end]) ≈ dot(fΔout, Δout) end +function _dot(p, q) + monos = monovec([monomials(p); monomials(q)]) + return dot(coefficient.(p, monos), coefficient.(q, monos)) +end +function _dot(px::Tuple, qx::Tuple) + return _dot(first(px), first(qx)) + _dot(Base.tail(px), Base.tail(qx)) +end +function _dot(::Tuple{}, ::Tuple{}) + return MultivariatePolynomials.MA.Zero() +end +function _dot(::NoTangent, ::NoTangent) + return MultivariatePolynomials.MA.Zero() +end + @testset "ChainRulesCore" begin Mod.@polyvar x y p = 1.1x + y @@ -42,30 +56,25 @@ end @test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent()) @test pullback(1x) == (NoTangent(), 2x^2, NoTangent()) - test_chain_rule(dot, +, (p,), (q,), p) - test_chain_rule(dot, +, (q,), (p,), q) + for d in [dot, _dot] + test_chain_rule(d, +, (p,), (q,), p) + test_chain_rule(d, +, (q,), (p,), q) - test_chain_rule(dot, -, (p,), (q,), p) - test_chain_rule(dot, -, (p,), (p,), q) + test_chain_rule(d, -, (p,), (q,), p) + test_chain_rule(d, -, (p,), (p,), q) - test_chain_rule(dot, +, (p, q), (q, p), p) - test_chain_rule(dot, +, (p, q), (p, q), q) + test_chain_rule(d, +, (p, q), (q, p), p) + test_chain_rule(d, +, (p, q), (p, q), q) - test_chain_rule(dot, -, (p, q), (q, p), 
p) - test_chain_rule(dot, -, (p, q), (p, q), q) + test_chain_rule(d, -, (p, q), (q, p), p) + test_chain_rule(d, -, (p, q), (p, q), q) + end - test_chain_rule(dot, *, (p, q), (q, p), p * q) - test_chain_rule(dot, *, (p, q), (p, q), q * q) - test_chain_rule(dot, *, (q, p), (p, q), q * q) - test_chain_rule(dot, *, (p, q), (q, p), q * q) + test_chain_rule(_dot, *, (p, q), (q, p), p * q) + test_chain_rule(_dot, *, (p, q), (p, q), q * q) + test_chain_rule(_dot, *, (q, p), (p, q), q * q) + test_chain_rule(_dot, *, (p, q), (q, p), q * q) - function _dot(p, q) - monos = monomials(p + q) - return dot(coefficient.(p, monos), coefficient.(q, monos)) - end - function _dot(px::Tuple{<:AbstractPolynomial,NoTangent}, qx::Tuple{<:AbstractPolynomial,NoTangent}) - return _dot(px[1], qx[1]) - end test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), p) test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(p, x)) test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(q, x))