Skip to content

Commit

Permalink
WIP: Issue 408 (#416)
Browse files Browse the repository at this point in the history
* WIP

* WIP

* WIP; work to close #408

* forward diff extension

* oops

* adjustment for testing

* version bump
  • Loading branch information
jverzani authored Feb 5, 2024
1 parent 83476df commit 17cea1c
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ docs/site
test/benchmarks.json
Manifest.toml
TODO.md
default.profraw
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Roots"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
version = "2.1.1"
version = "2.1.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
62 changes: 45 additions & 17 deletions ext/RootsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,60 @@

module RootsForwardDiffExt

using Roots
using ForwardDiff
import ForwardDiff: Dual, value, partials
import ForwardDiff: Dual, value, partials, Partials, derivative, gradient!

# For ForwardDiff we add a `solve` method for Dual types
# TODO (Issue #384) ForwardDiff.hessian fails, but this works:
#function hess(f, p)
# ∇(p) = ForwardDiff.gradient(f, p)
# ForwardDiff.jacobian(∇, p)
#end
# What works
# F(p) = find_zero(f, x0, M, p)
# G(p) = find_zero(𝐺(p), x0, M)
# F G
# ForwardDiff.derivative ✓ x (wrong answer, 0.0)
# ForwardDiff.gradient ✓ x (wrong answer, 0.0)
# ForwardDiff.hessian ✓ x (wrong answer, 0.0)
# Zygote.gradient ✓ ✓
# Zygote.hessian ✓ x (wrong answer!)
# Zygote.hessian_reverse ✓ x (MethodError)

function Roots.solve(ZP::ZeroProblem,
M::Roots.AbstractUnivariateZeroMethod,
𝐩::Union{Dual{T},
AbstractArray{<:Dual{T,<:Real}}
};
𝐩::Dual{T};
kwargs...) where {T}


# p_and_dp = 𝐩
p, dp = value.(𝐩), partials.(𝐩)

xᵅ = solve(ZP, M, p; kwargs...)

f = ZP.F
pᵥ = value.(𝐩)
xᵅ = solve(ZP, M, pᵥ; kwargs...)
𝐱ᵅ = Dual{T}(xᵅ, one(xᵅ))
fₓ = derivative(_x -> f(_x, p), xᵅ)
fₚ = derivative(_p -> f(xᵅ, _p), p)

fₓ = partials(f(𝐱ᵅ, pᵥ), 1)
fₚ = partials(f(xᵅ, 𝐩))
Dual{T}(xᵅ, - fₚ / fₓ)
# x and dx
dx = - (fₚ * dp) / fₓ

Dual{T}(xᵅ, dx)
end

# cf https://discourse.julialang.org/t/custom-rule-for-differentiating-through-newton-solver-using-forwarddiff-works-for-gradient-fails-for-hessian/93002/22
# Differentiate the zero of `ZP` with respect to an array of `Dual` parameters
# using the implicit function theorem: dx = -(∂f/∂p ⋅ dp) / (∂f/∂x).
# The zero itself is found on the plain values, so the solver never sees Duals.
function Roots.solve(ZP::ZeroProblem,
                     M::Roots.AbstractUnivariateZeroMethod,
                     𝐩::AbstractArray{<:Dual{T,R,N}};
                     kwargs...) where {T,R,N}

    # split the incoming duals into primal values and their partials
    vals = value.(𝐩)
    parts = partials.(𝐩)

    # primal solve
    root = solve(ZP, M, vals; kwargs...)

    F = ZP.F

    # ∂f/∂x at the zero
    ∂x = derivative(x -> F(x, vals), root)

    # ∂f/∂p at the zero, written into a buffer with the element type of `𝐩`
    # (the original notes that this buffer is needed rather than the output
    # of `gradient`)
    ∂p = similar(𝐩)
    gradient!(∂p, q -> F(root, q), vals)

    # partials of the zero
    δ = - (∂p' * parts) / ∂x

    return Dual{T}(root, Partials(ntuple(i -> δ[i], Val(N))))
end
end
108 changes: 103 additions & 5 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,33 @@
# ∇f = 0 => ∂/∂ₓ f(xᵅ, p) ⋅ ∂xᵅ/∂ₚ + ∂/∂ₚ f(xᵅ, p) ⋅ I = 0
# or ∂xᵅ/∂ₚ = - ∂/∂ₚ f(xᵅ, p) / ∂/∂ₓ f(xᵅ, p)

# There are two cases considered
# F(p) = find_zero(f(x,p), x₀, M, p) # f a function
# G(p) = find_zero(𝐺(p), x₀, M) # 𝐺 a functor
# For G(p) first order derivatives are working
# **but** hessian is not with Zygote. *MOREOVER* it fails
# with the **wrong answer** not an error.
#
# (`Zygote.hessian` calls `ForwardDiff` and that isn't working with a functor;
# `Zygote.hessian_reverse` doesn't seem to work here, though perhaps
# that is fixable.)


# this assumes a function and a parameter `p` passed in
import ChainRulesCore: Tangent, NoTangent, frule, rrule
function ChainRulesCore.frule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
(_, _, _, Δp),
::typeof(solve),
ZP::ZeroProblem,
M::AbstractUnivariateZeroMethod,
M::Roots.AbstractUnivariateZeroMethod,
p;
kwargs...,
)
xᵅ = solve(ZP, M, p; kwargs...)

# Use a single reverse-mode AD call with `rrule_via_ad` if `config` supports it?
F = p -> Callable_Function(M, ZP.F, p)
F = p -> Roots.Callable_Function(M, ZP.F, p)
fₓ(x) = first(F(p)(x))
fₚ(p) = first(F(p)(xᵅ))
fx = ChainRulesCore.frule_via_ad(config, (ChainRulesCore.NoTangent(), true), fₓ, xᵅ)[2]
Expand All @@ -24,23 +38,59 @@ function ChainRulesCore.frule(
xᵅ, -fp / fx
end

# Case of Functor carrying parameters
# Functor case: a trailing `nothing` parameter means "no parameter", so hand
# the call on to the `frule` method that takes no `p`.
function ChainRulesCore.frule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
    xdots,
    ::typeof(solve),
    ZP::Roots.ZeroProblem,
    M::Roots.AbstractUnivariateZeroMethod,
    ::Nothing;
    kwargs...,
)
    return frule(config, xdots, solve, ZP, M; kwargs...)
end

# Forward rule for `solve(ZP, M)` when no parameter `p` is passed and `ZP.F`
# is a functor carrying the parameters (issue 408).
#
# The functor is re-expressed as the parameter of a surrogate problem whose
# function is `|>` (`x |> foo` calls `foo(x)`), and the call is delegated to
# the `frule` method that takes a parameter. Keyword arguments are forwarded
# so that solver options apply on this path as well.
#
# NOTE(review): the tangent handed on is built from the functor's own field
# values, not from the incoming tangent `Δq`, which is unused here. The test
# file records `test_frule` as failing for this case; confirm the intended
# seeding.
function ChainRulesCore.frule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
    (_, Δq, _),
    ::typeof(solve),
    ZP::Roots.ZeroProblem,
    M::Roots.AbstractUnivariateZeroMethod;
    kwargs...,
)
    # no `p`; make ZP.F the parameter (issue 408)
    foo = ZP.F
    zprob2 = ZeroProblem(|>, ZP.x₀)

    # structural tangent for the functor, one entry per field
    nms = fieldnames(typeof(foo))
    nt = NamedTuple{nms}(getfield(foo, n) for n in nms)
    dfoo = Tangent{typeof(foo)}(; nt...)

    return frule(
        config,
        (NoTangent(), NoTangent(), NoTangent(), dfoo),
        Roots.solve,
        zprob2,
        M,
        foo;
        kwargs...,
    )
end


##

## modified from
## https://github.com/gdalle/ImplicitDifferentiation.jl/blob/main/src/implicit_function.jl
# this is for passing a parameter `p`
function ChainRulesCore.rrule(
rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(solve),
ZP::ZeroProblem,
M::AbstractUnivariateZeroMethod,
M::Roots.AbstractUnivariateZeroMethod,
p;
kwargs...,
)
xᵅ = solve(ZP, M, p; kwargs...)

f(x, p) = first(Callable_Function(M, ZP.F, p)(x))
f(x, p) = first(Roots.Callable_Function(M, ZP.F, p)(x))
_, pullback_f = ChainRulesCore.rrule_via_ad(rc, f, xᵅ, p)
_, fx, fp = pullback_f(true)
yp = -fp / fx

yp = -fp / fx
function pullback_solve_ZeroProblem(dy)
dp = yp * dy
return (
Expand All @@ -53,3 +103,51 @@ function ChainRulesCore.rrule(

return xᵅ, pullback_solve_ZeroProblem
end

# this assumes a functor 𝐺(p) for the function *and* no parameter
# Functor 𝐺(p) used as the function *and* an explicit `nothing` parameter:
# treat it as the no-parameter call and reuse that reverse rule.
function ChainRulesCore.rrule(
    rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
    ::typeof(solve),
    ZP::ZeroProblem,
    M::Roots.AbstractUnivariateZeroMethod,
    ::Nothing;
    kwargs...,
)
    return ChainRulesCore.rrule(rc, solve, ZP, M; kwargs...)
end


# Reverse rule for `solve(ZP, M)` when no parameter `p` is passed and `ZP.F`
# is a functor carrying the parameters.
#
# With f(x, F) = F(x), the implicit function theorem gives, for each field of
# the functor, ∂xᵅ/∂field = -(∂f/∂field) / (∂f/∂x). The pullback scales these
# by the incoming cotangent `dy`, as a pullback must be linear in `dy` (the
# rule that takes a parameter `p` does the same with `yp * dy`).
#
# Returns the zero `xᵅ` and the pullback. The pullback returns four entries
# (for `solve`, `ZP`, `M`, and a parameter slot), which also serves the method
# taking a trailing `nothing` that delegates here.
function ChainRulesCore.rrule(
    rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
    ::typeof(solve),
    ZP::ZeroProblem,
    M::Roots.AbstractUnivariateZeroMethod;
    kwargs...,
)
    # surrogate problem: `x |> F` calls `F(x)`, making the functor a parameter
    𝑍𝑃 = ZeroProblem(|>, ZP.x₀)
    xᵅ = solve(ZP, M; kwargs...)
    f(x, p) = first(Roots.Callable_Function(M, 𝑍𝑃.F, p)(x))

    # one reverse-mode call yields both ∂f/∂x and the tangent for the functor
    _, pullback_f = ChainRulesCore.rrule_via_ad(rc, f, xᵅ, ZP.F)
    _, fx, fp = pullback_f(true)

    # NOTE(review): the test file records `keys(::NoTangent)` failing here when
    # `fp` is not a structural tangent; that case is not handled.
    yp = NamedTuple{keys(fp)}(-fₚ / fx for fₚ in values(fp))

    function pullback_solve_ZeroProblem(dy)
        # scale each field's derivative by the incoming cotangent
        dF = ChainRulesCore.Tangent{typeof(ZP.F)}(; map(yₚ -> yₚ * dy, yp)...)

        dZP = ChainRulesCore.Tangent{typeof(ZP)}(;
            F = dF,
            x₀ = ChainRulesCore.NoTangent()
        )

        dsolve = ChainRulesCore.NoTangent()
        dM = ChainRulesCore.NoTangent()
        dp = ChainRulesCore.NoTangent()

        return dsolve, dZP, dM, dp
    end

    return xᵅ, pullback_solve_ZeroProblem
end
32 changes: 31 additions & 1 deletion test/test_chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ using Zygote
using Test

# issue #325 add frule, rrule
# Functor carrying a single parameter `p`; used for the issue #408 tests.
struct 𝐺
p
end
# Calling the functor evaluates cos(x) - p*x.
(g::𝐺)(x) = cos(x) - g.p * x
# Zero of the functor on (0, pi/2); the parameter lives inside the functor.
G₃(p) = find_zero(𝐺(p), (0, pi/2), Bisection())
# The same equation with `p` passed to `find_zero` as an explicit parameter.
F₃(p) = find_zero((x,p) -> cos(x) - p*x, (0, pi/2), Bisection(), p)


@testset "Test frule and rrule" begin
# Type inference tests of `test_frule` and `test_rrule` with the default
Expand Down Expand Up @@ -33,7 +40,7 @@ using Test
G(p) = find_zero(g, 1, Order1(), p)
@test first(Zygote.gradient(G, [0, 4])) ≈ [1 / 2, 1 / 4]

# a tuple of functions
# a tuple of functions
fx(x, p) = 1 / x
test_frule(solve, ZeroProblem((f, fx), 1), Roots.Newton(), 1.0; check_inferred=false)
test_rrule(solve, ZeroProblem((f, fx), 1), Roots.Newton(), 1.0; check_inferred=false)
Expand Down Expand Up @@ -67,4 +74,27 @@ using Test
)
G2(p) = find_zero((g, gx), 1, Roots.Newton(), p)
@test first(Zygote.gradient(G2, [0, 4])) ≈ [1 / 2, 1 / 4]

# test Functor; issue #408
x = rand()
@test first(Zygote.gradient(F₃, x)) ≈ first(Zygote.gradient(G₃, x))
# ForwardDiff extension makes this fail.
VERSION >= v"1.9.0" && @test_broken first(Zygote.hessian(F₃, x)) ≈ first(Zygote.hessian(G₃, x))
# test_frule, test_rrule aren't successful
#=
# DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 2
test_frule(
solve,
ZeroProblem(𝐺(2), (0.0, pi/2)),
Roots.Bisection();
check_inferred=false,
)
# MethodError: no method matching keys(::NoTangent)
test_rrule(
solve,
ZeroProblem(𝐺(2), (0.0, pi/2)),
Roots.Bisection();
check_inferred=false,
)
=#
end
6 changes: 2 additions & 4 deletions test/test_extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,17 @@ using ForwardDiff
@test ForwardDiff.derivative(F, p) ≈ 1 / (2sqrt(p))
end

# Hessian is *broken*
# Hessian is *fixed* for F(p) = find_zero(f, x₀, M, p)
f = (x, p) -> x^2 - sum(p .^ 2)
Z = ZeroProblem(f, (0, 1000))
F = p -> solve(Z, Roots.Bisection(), p)
Z = ZeroProblem(f, (0, 1000))
F = p -> solve(Z, Roots.Bisection(), p)
hess(f, p) = ForwardDiff.jacobian(p -> ForwardDiff.gradient(F, p), p)
for p ∈ ([1,2], [1,3], [1,4])
@test F(p) ≈ sqrt(sum(p .^ 2))
@test_throws DimensionMismatch ForwardDiff.hessian(F, p)
a, b = p
n = sqrt(a^2 + b^2)^3
@test hess(F, p) ≈ [b^2 -a*b; -a*b a^2] / n
@test ForwardDiff.hessian(F, p) ≈ [b^2 -a*b; -a*b a^2] / n
end
end

Expand Down
4 changes: 4 additions & 0 deletions tmp/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

2 comments on commit 17cea1c

@jverzani
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/100302

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.1.2 -m "<description of version>" 17cea1c8d8ee1d2f84d27cd5f595324a3f7e0873
git push origin v2.1.2

Please sign in to comment.