Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rrule for * and add support for constant operations #211

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,108 @@
# The pullback depends on the scalar product on the polynomials
# With the scalar product `LinearAlgebra.dot(p, q) = p * q`, there is no pullback for `differentiate`
# With the scalar product `_dot(p, q)` of `test/chain_rules.jl`, there is a pullback for `differentiate`
# and the pullback for `*` changes.
# We give the one for the scalar product `_dot`.

import ChainRulesCore

# Unary `+` and `-` on polynomial-like values: the declared partial derivatives
# are `true` and `-1` respectively.
ChainRulesCore.@scalar_rule +(x::APL) true
ChainRulesCore.@scalar_rule -(x::APL) -1

# Binary `+`: the declared partials with respect to `x` and `y` are both `true`.
ChainRulesCore.@scalar_rule +(x::APL, y::APL) (true, true)
# Pullback returned by the `rrule` of `plusconstant(p, α)`.
# The cotangent of `p` is `Δ` itself; the cotangent of `α` is the coefficient of
# the constant monomial of `Δ`.
function plusconstant1_pullback(Δ)
    Δα = coefficient(Δ, constantmonomial(Δ))
    return (ChainRulesCore.NoTangent(), Δ, Δα)
end
# Reverse rule for `plusconstant(p, α)` where the polynomial is the first argument.
function ChainRulesCore.rrule(::typeof(plusconstant), p::APL, α)
    Ω = plusconstant(p, α)
    return Ω, plusconstant1_pullback
end
# Pullback returned by the `rrule` of `plusconstant(α, p)`.
# The cotangent of `α` is the coefficient of the constant monomial of `Δ`; the
# cotangent of `p` is `Δ` itself.
function plusconstant2_pullback(Δ)
    Δα = coefficient(Δ, constantmonomial(Δ))
    return (ChainRulesCore.NoTangent(), Δα, Δ)
end
# Reverse rule for `plusconstant(α, p)` where the constant is the first argument.
function ChainRulesCore.rrule(::typeof(plusconstant), α, p::APL)
    Ω = plusconstant(α, p)
    return Ω, plusconstant2_pullback
end
# Binary `-`: the declared partials are `true` for `x` and `-1` for `y`.
ChainRulesCore.@scalar_rule -(x::APL, y::APL) (true, -1)

# Forward rule for the product of two polynomials.
# Returns the primal `p * q` together with the tangent built from `p * Δq`, `q` and `Δp`.
function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL)
    Ω = p * q
    # `p * Δq` is a freshly computed value handed to `add_mul!!` together with `q`
    # and `Δp` (presumably yielding `p * Δq + q * Δp` — see the MutableArithmetics docs).
    ΔΩ = MA.add_mul!!(p * Δq, q, Δp)
    return Ω, ΔΩ
end

# Shared kernel of `adjoint_mult_left` and `adjoint_mult_right`.
#
# For every term `t` of `p` and every element `δ` of `Δ` such that the monomial of
# `t` divides `δ`, push onto `ts` the term whose coefficient is
# `op(coefficient(t), coefficient(δ))` and whose monomial is the quotient of the
# monomial of `δ` by the monomial of `t`. The polynomial built from `ts` is returned.
#
# `ts` is mutated: callers are expected to pass a fresh, correctly typed vector.
function _mult_pullback(op::F, ts, p, Δ) where {F<:Function}
    for t in terms(p)
        c = coefficient(t)
        m = monomial(t)
        for δ in Δ
            # Only the elements of `Δ` that are multiples of `m` contribute.
            divides(m, δ) || continue
            scaled = op(c, coefficient(δ))
            push!(ts, term(scaled, _div(monomial(δ), m)))
        end
    end
    return polynomial(ts)
end
# Apply `_mult_pullback` with the adjoint of the coefficients of `p` multiplied on
# the left of the coefficients of `Δ` (`c' * d`).
function adjoint_mult_left(p, Δ)
    adjoint_term = MA.promote_operation(adjoint, termtype(p))
    # Element type of the vector in which the resulting terms are accumulated.
    T = MA.promote_operation(*, adjoint_term, termtype(Δ))
    return _mult_pullback((c, d) -> c' * d, T[], p, Δ)
end
# Apply `_mult_pullback` with the adjoint of the coefficients of `p` multiplied on
# the right of the coefficients of `Δ` (`d * c'`).
function adjoint_mult_right(p, Δ)
    adjoint_term = MA.promote_operation(adjoint, termtype(p))
    # Element type of the vector in which the resulting terms are accumulated.
    T = MA.promote_operation(*, termtype(Δ), adjoint_term)
    return _mult_pullback((c, d) -> d * c', T[], p, Δ)
end

# Reverse rule for the product of two polynomials.
#
# The pullback depends on the scalar product chosen on the polynomials:
# * For `dot(p, q) = p * q`, the cotangents would be `ΔΩ̇ * q'` and `p' * ΔΩ̇`
#   (this was the previous definition).
# * For `dot(p, q) = dot(coefficients(p), coefficients(q))` (assuming `p` and `q`
#   have the same monomials), the cotangents are the ones returned below. This second
#   scalar product corresponds to the `_dot` used for testing `differentiate`, so the
#   rule for `*` is kept consistent with the one for `differentiate`.
function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL)
    function times_pullback2(ΔΩ̇)
        # This is for the scalar product `_dot`:
        return (
            ChainRulesCore.NoTangent(),
            adjoint_mult_right(q, ΔΩ̇),
            adjoint_mult_left(p, ΔΩ̇),
        )
        # For the scalar product `dot`, it would be instead:
        # return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇)
    end
    return p * q, times_pullback2
end

# Reverse rule for `multconstant(α, p)` where the constant is the first argument.
# The cotangent of `α` is the constant coefficient of `adjoint_mult_right(p, ΔΩ̇)`
# and the cotangent of `p` is `α' * ΔΩ̇`.
function ChainRulesCore.rrule(::typeof(multconstant), α, p::APL)
    Ω = multconstant(α, p)
    function constant_times_poly_pullback(ΔΩ̇)
        # TODO we could make it faster, don't need to compute `Δα` entirely if we
        # only care about the constant term.
        Δα = adjoint_mult_right(p, ΔΩ̇)
        Δα_constant = coefficient(Δα, constantmonomial(Δα))
        return (ChainRulesCore.NoTangent(), Δα_constant, α' * ΔΩ̇)
    end
    return Ω, constant_times_poly_pullback
end

# Reverse rule for `multconstant(p, α)` where the polynomial is the first argument.
# The cotangent of `p` is `ΔΩ̇ * α'` and the cotangent of `α` is the constant
# coefficient of `adjoint_mult_left(p, ΔΩ̇)`.
function ChainRulesCore.rrule(::typeof(multconstant), p::APL, α)
    Ω = multconstant(p, α)
    function poly_times_constant_pullback(ΔΩ̇)
        # TODO we could make it faster, don't need to compute `Δα` entirely if we
        # only care about the constant term.
        Δα = adjoint_mult_left(p, ΔΩ̇)
        Δp = ΔΩ̇ * α'
        return (ChainRulesCore.NoTangent(), Δp, coefficient(Δα, constantmonomial(Δα)))
    end
    return Ω, poly_times_constant_pullback
end

# Pullback that ignores `Δ` and returns a triple of `NoTangent()`.
function notangent3(Δ)
    no_tangent = ChainRulesCore.NoTangent()
    return (no_tangent, no_tangent, no_tangent)
end
# Reverse rule for an integer power of a monomial-like value: the primal is
# `mono^i` and the pullback `notangent3` yields no tangent for any argument.
function ChainRulesCore.rrule(::typeof(^), mono::AbstractMonomialLike, i::Integer)
    Ω = mono^i
    return Ω, notangent3
end

# Forward rule for `differentiate(p, x)`: the tangent of the output is the
# derivative of the input tangent `Δp` with respect to the same `x`.
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
    dpdx = differentiate(p, x)
    Δdpdx = differentiate(Δp, x)
    return dpdx, Δdpdx
end
# Pullback of `differentiate(p, x)` for the cotangent `Δdpdx`.
# This is for the scalar product `_dot`, there is no pullback for the scalar product `dot`
function differentiate_pullback(Δdpdx, x)
    Δp = x * differentiate(x * Δdpdx, x)
    return ChainRulesCore.NoTangent(), Δp, ChainRulesCore.NoTangent()
end

# Reverse rule for `differentiate(p, x)`: returns the derivative and a pullback in
# which `x` is bound as the second argument of `differentiate_pullback`.
function ChainRulesCore.rrule(::typeof(differentiate), p, x)
    dpdx = differentiate(p, x)
    return dpdx, Base.Fix2(differentiate_pullback, x)
end

# Pullback of `coefficient(p, m)`: the cotangent of `p` is the polynomial made of
# the single term `Δ * m`; the function and `m` get no tangent.
function coefficient_pullback(Δ, m::AbstractMonomialLike)
    Δp = polynomial(term(Δ, m))
    return (ChainRulesCore.NoTangent(), Δp, ChainRulesCore.NoTangent())
end
# Reverse rule for `coefficient(p, m)`: the pullback is `coefficient_pullback`
# with `m` bound as its second argument.
function ChainRulesCore.rrule(::typeof(coefficient), p::APL, m::AbstractMonomialLike)
    c = coefficient(p, m)
    pullback_for_m = Base.Fix2(coefficient_pullback, m)
    return c, pullback_for_m
end
47 changes: 28 additions & 19 deletions test/chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout)
@test dot(Δin, rΔin[2:end]) ≈ dot(fΔout, Δout)
end

# Scalar product of two polynomials taken coefficient-wise: both coefficient
# vectors are read over the monomial vector built from the monomials of `p` and `q`.
function _dot(p, q)
    monos = monovec(vcat(monomials(p), monomials(q)))
    p_coefs = coefficient.(p, monos)
    q_coefs = coefficient.(q, monos)
    return dot(p_coefs, q_coefs)
end
# `_dot` of two tuples: `_dot` of the first entries plus, recursively, `_dot` of
# the remaining entries.
function _dot(px::Tuple, qx::Tuple)
    head = _dot(first(px), first(qx))
    rest = _dot(Base.tail(px), Base.tail(qx))
    return head + rest
end
# Base case of the tuple recursion: two empty tuples give `MA.Zero()`.
_dot(::Tuple{}, ::Tuple{}) = MultivariatePolynomials.MA.Zero()
# A pair of `NoTangent` entries contributes `MA.Zero()` to the scalar product.
_dot(::NoTangent, ::NoTangent) = MultivariatePolynomials.MA.Zero()

@testset "ChainRulesCore" begin
Mod.@polyvar x y
p = 1.1x + y
Expand Down Expand Up @@ -42,30 +56,25 @@ end
@test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent())
@test pullback(1x) == (NoTangent(), 2x^2, NoTangent())

test_chain_rule(dot, +, (p,), (q,), p)
test_chain_rule(dot, +, (q,), (p,), q)
for d in [dot, _dot]
test_chain_rule(d, +, (p,), (q,), p)
test_chain_rule(d, +, (q,), (p,), q)

test_chain_rule(dot, -, (p,), (q,), p)
test_chain_rule(dot, -, (p,), (p,), q)
test_chain_rule(d, -, (p,), (q,), p)
test_chain_rule(d, -, (p,), (p,), q)

test_chain_rule(dot, +, (p, q), (q, p), p)
test_chain_rule(dot, +, (p, q), (p, q), q)
test_chain_rule(d, +, (p, q), (q, p), p)
test_chain_rule(d, +, (p, q), (p, q), q)

test_chain_rule(dot, -, (p, q), (q, p), p)
test_chain_rule(dot, -, (p, q), (p, q), q)
test_chain_rule(d, -, (p, q), (q, p), p)
test_chain_rule(d, -, (p, q), (p, q), q)
end

test_chain_rule(dot, *, (p, q), (q, p), p * q)
test_chain_rule(dot, *, (p, q), (p, q), q * q)
test_chain_rule(dot, *, (q, p), (p, q), q * q)
test_chain_rule(dot, *, (p, q), (q, p), q * q)
test_chain_rule(_dot, *, (p, q), (q, p), p * q)
test_chain_rule(_dot, *, (p, q), (p, q), q * q)
test_chain_rule(_dot, *, (q, p), (p, q), q * q)
test_chain_rule(_dot, *, (p, q), (q, p), q * q)

function _dot(p, q)
monos = monomials(p + q)
return dot(coefficient.(p, monos), coefficient.(q, monos))
end
function _dot(px::Tuple{<:AbstractPolynomial,NoTangent}, qx::Tuple{<:AbstractPolynomial,NoTangent})
return _dot(px[1], qx[1])
end
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), p)
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(p, x))
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(q, x))
Expand Down