Skip to content

Commit

Permalink
Clean up AD
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 26, 2024
1 parent 24f83d8 commit ae0bf10
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DiffResults = "1.1"
DifferentiationInterface = "0.4"
ExplicitImports = "1.5.0"
FastClosures = "0.3.2"
FiniteDiff = "2.22"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
LinearSolve = "2.30"
Expand Down
8 changes: 4 additions & 4 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init,
remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace,
_unwrap_val
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end

Expand Down
92 changes: 34 additions & 58 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{
<:AbstractArray, iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval function SciMLBase.solve(

Check warning on line 2 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L2

Added line #L2 was not covered by tests
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(

Check warning on line 10 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L8-L10

Added lines #L8 - L10 were not covered by tests
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand Down Expand Up @@ -47,8 +37,7 @@ function __nlsolve_ad(
tspan = value.(prob.tspan)
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
newprob = remake(prob; p, u0 = value(prob.u0))

Check warning on line 40 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L40

Added line #L40 was not covered by tests
end

sol = solve(newprob, alg, args...; kwargs...)
Expand All @@ -73,20 +62,16 @@ function __nlsolve_ad(
end

function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
p = value(prob.p)
u0 = value(prob.u0)
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)

newprob = remake(prob; p = value(prob.p), u0 = value(prob.u0))

Check warning on line 65 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L65

Added line #L65 was not covered by tests
sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
resid = __similar(du, length(sol.resid))

Check warning on line 74 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L74

Added line #L74 was not covered by tests
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
Expand All @@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
elseif SciMLBase.has_jac(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
J = __similar(du, length(sol.resid), length(u))

Check warning on line 89 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L89

Added line #L89 was not covered by tests
prob.f.jac(J, u, p)
resid = similar(du, length(sol.resid))
resid = __similar(du, length(sol.resid))

Check warning on line 91 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L91

Added line #L91 was not covered by tests
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
Expand All @@ -116,43 +101,38 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
res = DiffResults.DiffResult(
resid, similar(du, length(sol.resid), length(u)))
_f = @closure (du, u) -> prob.f(du, u, p)
ForwardDiff.jacobian!(res, _f, resid, u)
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
DiffResults.jacobian(res), 2, false)
resid = __similar(du, length(sol.resid))
v, J = DI.value_and_jacobian(_f, resid, AutoForwardDiff(), u)
mul!(reshape(du, 1, :), vec(v)', J, 2, false)

Check warning on line 107 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L105-L107

Added lines #L105 - L107 were not covered by tests
return nothing
end
else
# For small problems, nesting ForwardDiff is actually quite fast
_f = Base.Fix2(prob.f, newprob.p)

Check warning on line 112 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L112

Added line #L112 was not covered by tests
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
# TODO: Remove once DI has the value_and_pullback_split defined
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p)

Check warning on line 115 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L115

Added line #L115 was not covered by tests
else
_F = @closure (u, p) -> begin
T = promote_type(eltype(u), eltype(p))
res = DiffResults.DiffResult(similar(u, T, size(sol.resid)),
similar(u, T, length(sol.resid), length(u)))
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
return reshape(
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
size(u))
_f = Base.Fix2(prob.f, p)
v, J = DI.value_and_jacobian(_f, AutoForwardDiff(), u)
return reshape(2 .* vec(v)' * J, size(u))

Check warning on line 120 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L118-L120

Added lines #L118 - L120 were not covered by tests
end
end
end
end

f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
f_p = __nlsolve_∂f_∂p(prob, _F, uu, newprob.p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, newprob.p)

Check warning on line 127 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L126-L127

Added lines #L126 - L127 were not covered by tests

z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))
elseif p isa Number
elseif pp isa Number

Check warning on line 135 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L135

Added line #L135 was not covered by tests
partials = sumfun((z_arr, pp))
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
Expand All @@ -164,7 +144,7 @@ end
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
du = similar(u, promote_type(eltype(u), eltype(p)))
du = __similar(u, promote_type(eltype(u), eltype(p)))

Check warning on line 147 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L147

Added line #L147 was not covered by tests
f(du, u, p)
return du
end
Expand All @@ -182,16 +162,12 @@ end

@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
du = similar(u)
__f = (du, u) -> f(du, u, p)
ForwardDiff.jacobian(__f, du, u)
__f = @closure (du, u) -> f(du, u, p)
return ForwardDiff.jacobian(__f, __similar(u), u)

Check warning on line 166 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
else
__f = Base.Fix2(f, p)
if u isa Number
return ForwardDiff.derivative(__f, u)
else
return ForwardDiff.jacobian(__f, u)
end
u isa Number && return ForwardDiff.derivative(__f, u)
return ForwardDiff.jacobian(__f, u)

Check warning on line 170 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L169-L170

Added lines #L169 - L170 were not covered by tests
end
end

Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
@bb xo = copy(x)

if setindex_trait(x) === CanSetindex()
A = similar(x, length(x), length(x))
Aaᵢ = similar(x, length(x))
cᵢ = similar(x)
A = __similar(x, length(x), length(x))
Aaᵢ = __similar(x, length(x))
cᵢ = __similar(x)
else
A = x
Aaᵢ = x
Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ end
return :(return SVector{$N, $T}(($(getcalls...))))
end

__lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold)
__lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = __similar(x, threshold)
function __lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
return zeros(MArray{Tuple{threshold}, eltype(x)})
end
Expand All @@ -298,7 +298,7 @@ end
end
end
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
Vᵀ = similar(u, threshold, length(u))
U = similar(u, length(fu), threshold)
Vᵀ = __similar(u, threshold, length(u))
U = __similar(u, length(fu), threshold)
return U, Vᵀ
end
32 changes: 23 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum.
"""
__max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))

function __fixed_parameter_function(prob::NonlinearProblem)
function __fixed_parameter_function(prob::AbstractNonlinearProblem)
isinplace(prob) && return @closure (du, u) -> prob.f(du, u, prob.p)
return Base.Fix2(prob.f, prob.p)
end

function value_and_jacobian(
ad, prob::NonlinearProblem, f::F, y, x, cache; J = nothing) where {F}
ad, prob::AbstractNonlinearProblem, f::F, y, x, cache; J = nothing) where {F}
x isa Number && return DI.value_and_derivative(f, ad, x, cache)

if isinplace(prob)
Expand All @@ -46,29 +46,30 @@ function value_and_jacobian(
end
end

function jacobian_cache(ad, prob::NonlinearProblem, f::F, y, x) where {F}
function jacobian_cache(ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F}
x isa Number && return (nothing, DI.prepare_derivative(f, ad, x))

if isinplace(prob)
J = similar(y, length(y), length(x))
J = __similar(y, length(y), length(x))
SciMLBase.has_jac(prob.f) && return J, HasAnalyticJacobian()
return J, DI.prepare_jacobian(f, y, ad, x)
else
SciMLBase.has_jac(prob.f) && return nothing, HasAnalyticJacobian()
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
J = ArrayInterface.can_setindex(x) ? __similar(y, length(y), length(x)) : nothing
return J, DI.prepare_jacobian(f, ad, x)
end
end

function compute_jacobian_and_hessian(ad, prob::NonlinearProblem, f::F, y, x) where {F}
function compute_jacobian_and_hessian(
ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F}
if x isa Number
df = @closure x -> DI.derivative(f, ad, x)
return f(x), df(x), DI.derivative(df, ad, x)
end

if isinplace(prob)
df = @closure x -> begin

Check warning on line 71 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L71

Added line #L71 was not covered by tests
res = similar(y, promote_type(eltype(y), eltype(x)))
res = __similar(y, promote_type(eltype(y), eltype(x)))
return DI.jacobian(f, res, ad, x)
end
J, H = DI.value_and_jacobian(df, ad, x)
Expand All @@ -83,7 +84,7 @@ end
__init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
__init_identity_jacobian!!(J::Number) = one(J)
function __init_identity_jacobian(u, fu, α = true)
J = similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
fill!(J, zero(eltype(J)))
J[diagind(J)] .= eltype(J)(α)
return J
Expand Down Expand Up @@ -129,7 +130,7 @@ end
T = eltype(x)
return T.(f.resid_prototype)
else
fx = similar(x)
fx = __similar(x)
f(fx, x, p)
return fx
end
Expand Down Expand Up @@ -242,3 +243,16 @@ end

# Extension
function __zygote_compute_nlls_vjp end

function __similar(x, args...; kwargs...)
y = similar(x, args...; kwargs...)
return __init_bigfloat_array!!(y)
end

function __init_bigfloat_array!!(x)
if ArrayInterface.can_setindex(x)
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
return x
end
return x

Check warning on line 257 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L257

Added line #L257 was not covered by tests
end

0 comments on commit ae0bf10

Please sign in to comment.