Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start using termination conditions from DiffEqBase #208

Merged
merged 14 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]
53 changes: 44 additions & 9 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ end
f
alg
u
u_prev
du
fu
fu2
Expand All @@ -46,17 +47,21 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
reset_tolerance
reset_check
prob
stats::NLStats
lscache
termination_condition
tc_storage
end

get_fu(cache::GeneralBroydenCache) = cache.fu

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand All @@ -65,23 +70,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
alg.reset_tolerance
reset_check = x -> abs(x) ≤ reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
T = eltype(u)

mul!(_vec(du), J⁻¹, -_vec(fu))
α = perform_linesearch!(cache.lscache, u, du)
_axpy!(α, du, u)
f(fu2, u, p)

cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand All @@ -106,20 +126,25 @@ function perform_step!(cache::GeneralBroydenCache{true})
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
end
fu .= fu2
@. u_prev = u

return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache
@unpack f, p, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = cache.u .+ α * cache.du
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand All @@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
41 changes: 34 additions & 7 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,16 @@ end
internalnorm
retcode::SciMLBase.ReturnCode.T
abstol
reltol
prob
stats::NLStats
termination_condition
tc_storage
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)

Expand Down Expand Up @@ -122,14 +126,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁

ℋ = fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
T)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
termination_condition, storage)
end

function perform_step!(cache::DFSaneCache{true})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
f = (dx, x) -> cache.prob.f(dx, x, cache.p)

T = eltype(cache.uₙ)
Expand Down Expand Up @@ -174,7 +191,7 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if cache.internalnorm(cache.fuₙ) < cache.abstol
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end

Expand Down Expand Up @@ -205,8 +222,9 @@ function perform_step!(cache::DFSaneCache{true})
end

function perform_step!(cache::DFSaneCache{false})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
f = x -> cache.prob.f(x, cache.p)

T = eltype(cache.uₙ)
Expand Down Expand Up @@ -249,7 +267,7 @@ function perform_step!(cache::DFSaneCache{false})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if cache.internalnorm(cache.fuₙ) < cache.abstol
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end

Expand Down Expand Up @@ -296,7 +314,9 @@ function SciMLBase.solve!(cache::DFSaneCache)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
Expand All @@ -317,7 +337,14 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
62 changes: 52 additions & 10 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
fu_new
Expand All @@ -72,12 +75,17 @@
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
stats::NLStats
tc_storage
termination_condition
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand All @@ -101,15 +109,29 @@
JᵀJ, Jᵀf = nothing, nothing
end

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u); mode = NLSolveTerminationMode.AbsNorm)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
reltol,
prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du, tc_storage = cache
jacobian!!(J, cache)

termination_condition = cache.termination_condition(tc_storage)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
Expand All @@ -127,9 +149,15 @@
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(termination_condition(cache.fu_new .- cache.fu1,
cache.u,
u_prev,
cache.abstol,
cache.reltol) ||
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
(cache.force_stop = true)

@. u_prev = u
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand All @@ -139,7 +167,9 @@
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)

Expand All @@ -164,7 +194,10 @@
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.u_prev = @. cache.u
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand All @@ -174,7 +207,9 @@
end

function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -184,7 +219,14 @@
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,

Check warning on line 222 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L222

Added line #L222 was not covered by tests
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition

Check warning on line 229 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L228-L229

Added lines #L228 - L229 were not covered by tests
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
Loading
Loading