Skip to content

Commit

Permalink
Fix forward AD
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 11, 2023
1 parent 47135d4 commit 5e7bafc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
57 changes: 39 additions & 18 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,61 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
p = value(prob.p)

u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
end
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)

Check warning on line 11 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L10-L11

Added lines #L10 - L11 were not covered by tests

z_arr = -inv(f_x) * f_p

Check warning on line 13 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L13

Added line #L13 was not covered by tests

f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
pp = prob.p
sumfun = let f_x′ = -f_x
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
sumfun = ((z, p),) -> [zᵢ * ForwardDiff.partials(p) for zᵢ in z]
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))

Check warning on line 18 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))

Check warning on line 20 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L20

Added line #L20 was not covered by tests
end
partials = sum(sumfun, zip(f_p, pp))

return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
<:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},

Check warning on line 26 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L26

Added line #L26 was not covered by tests
iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)

Check warning on line 31 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L30-L31

Added lines #L30 - L31 were not covered by tests
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
<:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},

Check warning on line 34 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L34

Added line #L34 was not covered by tests
iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)

Check warning on line 39 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
end

function scalar_nlsolve_∂f_∂p(f, u, p)
ff = p isa Number ? ForwardDiff.derivative :

Check warning on line 43 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L42-L43

Added lines #L42 - L43 were not covered by tests
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
return ff(Base.Fix1(f, u), p)

Check warning on line 45 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L45

Added line #L45 was not covered by tests
end

function scalar_nlsolve_∂f_∂u(f, u, p)
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
return ff(Base.Fix2(f, p), u)

Check warning on line 50 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L48-L50

Added lines #L48 - L50 were not covered by tests
end

function scalar_nlsolve_dual_soln(u::Number, partials,

Check warning on line 53 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L53

Added line #L53 was not covered by tests
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials[1])

Check warning on line 55 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L55

Added line #L55 was not covered by tests
end

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,

Check warning on line 58 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L58

Added line #L58 was not covered by tests
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))

Check warning on line 60 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L60

Added line #L60 was not covered by tests
end
30 changes: 9 additions & 21 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,13 @@ end
@test (@ballocated solve!($cache)) ≤ 64
end

# FIXME: Even the previous tests were broken, but due to a typo in the tests they
# accidentally passed
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
end

Expand Down Expand Up @@ -101,11 +99,9 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p)

probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(),
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0))
end
Expand Down Expand Up @@ -149,8 +145,6 @@ end
@test (@ballocated solve!($cache)) ≤ 64
end

# FIXME: Even the previous tests were broken, but due to a typo in the tests they
# accidentally passed
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
p in 1.0:0.1:100.0

Expand All @@ -160,7 +154,7 @@ end
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) ≈ 1 / (2 * sqrt(p))
end

Expand Down Expand Up @@ -204,11 +198,9 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p)

probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
@testset "ADType: $(autodiff) u0: $(u0) radius_update_scheme: $(radius_update_scheme)" for autodiff in (false,
@testset "ADType: $(autodiff) u0: $(_nameof(u0)) radius_update_scheme: $(radius_update_scheme)" for autodiff in (false,
true, AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(), AutoSparseEnzyme()),
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]),
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0]),
radius_update_scheme in radius_update_schemes

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -302,15 +294,13 @@ end
@test (@ballocated solve!($cache)) ≤ 64
end

# FIXME: Even the previous tests were broken, but due to a typo in the tests they
# accidentally passed
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
end

Expand All @@ -330,11 +320,9 @@ end
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], p) ≈
ForwardDiff.jacobian(t, p)

probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(),
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0))
end
Expand Down

0 comments on commit 5e7bafc

Please sign in to comment.