Skip to content

Add verbose for collocation methods #290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
__concrete_nonlinearsolve_algorithm,
__internal_nlsolve_problem, BoundaryValueDiffEqAlgorithm,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
__extract_mesh, get_dense_ad, __get_bcresid_prototype
__extract_mesh, get_dense_ad, __get_bcresid_prototype,
__split_kwargs
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
using DifferentiationInterface: DifferentiationInterface, Constant, prepare_jacobian
Expand Down
31 changes: 16 additions & 15 deletions lib/BoundaryValueDiffEqAscher/src/ascher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function get_fixed_points(prob::BVProblem, alg::AbstractAscher)
end

function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0,
adaptive = true, abstol = 1e-4, kwargs...)
adaptive = true, abstol = 1e-4, verbose = true, kwargs...)
(; tspan, p) = prob
_, T, ncy, n, u0 = __extract_problem_details(prob; dt, check_positive_dt = true)
t₀, t₁ = tspan
Expand Down Expand Up @@ -145,27 +145,24 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0,
cache = AscherCache{iip, T}(
prob, f, jac, bc, bcjac, k, copy(mesh), mesh, mesh_dt, ncomp, ny, p,
zeta, fixpnt, alg, prob.problem_type, bcresid_prototype, residual,
zval, yval, gval, err, g, w, v, lz, ly, dmz, delz, deldmz, dqdmz,
dmv, pvtg, pvtw, TU, valst, (; abstol, dt, adaptive, kwargs...))
zval, yval, gval, err, g, w, v, lz, ly, dmz, delz, deldmz, dqdmz, dmv,
pvtg, pvtw, TU, valst, (; abstol, dt, adaptive, verbose, kwargs...))
return cache
end

function __split_ascher_kwargs(; abstol, dt, adaptive = true, kwargs...)
return ((abstol, adaptive, dt), (; abstol, adaptive, kwargs...))
end

function SciMLBase.solve!(cache::AscherCache{iip, T}) where {iip, T}
(abstol, adaptive, _), kwargs = __split_ascher_kwargs(; cache.kwargs...)
(abstol, adaptive, verbose, _), kwargs = __split_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

# We do the first iteration outside the loop to preserve type-stability of the
# `original` field of the solution
z, y, info, error_norm = __perform_ascher_iteration(cache, abstol, adaptive; kwargs...)
z, y, info, error_norm = __perform_ascher_iteration(
cache, abstol, adaptive, verbose; kwargs...)

if adaptive
while SciMLBase.successful_retcode(info) && norm(error_norm) > abstol
z, y, info, error_norm = __perform_ascher_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)
end
end
u = [vcat(zᵢ, yᵢ) for (zᵢ, yᵢ) in zip(z, y)]
Expand All @@ -174,12 +171,13 @@ function SciMLBase.solve!(cache::AscherCache{iip, T}) where {iip, T}
cache.prob, cache.alg, cache.original_mesh, u; retcode = info)
end

function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive::Bool;
nlsolve_kwargs = (;), kwargs...) where {iip, T}
function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive::Bool,
verbose::Bool; nlsolve_kwargs = (;), kwargs...) where {iip, T}
info::ReturnCode.T = ReturnCode.Success
nlprob = __construct_nlproblem(cache)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
nlsol = __solve(nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs...)
nlsol = __solve(nlprob, nlsolve_alg; abstol = abstol,
verbose = verbose, kwargs..., nlsolve_kwargs...)
error_norm = 2 * abstol
info = nlsol.retcode

Expand Down Expand Up @@ -207,16 +205,19 @@ function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive
__expand_cache_for_error!(cache)

_nlprob = __construct_nlproblem(cache)
nlsol = __solve(_nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs...)
nlsol = __solve(_nlprob, nlsolve_alg; abstol = abstol,
verbose = verbose, kwargs..., nlsolve_kwargs...)

error_norm = error_estimate!(cache)
if norm(error_norm) > abstol
verbose && @warn "Global error norm bigger than tolerance, refining mesh"
mesh_selector!(cache, z, dmz, mesh, mesh_dt, abstol)
__expand_cache_for_next_iter!(cache)
end
else # Something bad happened
if 2 * (length(cache.mesh) - 1) > cache.alg.max_num_subintervals
# The solving process failed
# New mesh would be too large
verbose && @warn "Mesh being too large and still failing to solve, exiting"
info = ReturnCode.Failure
else
# doesn't need to halve the mesh again, just use the expanded cache
Expand Down
5 changes: 5 additions & 0 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,8 @@ function _sparse_like(I, J, x::AbstractArray, m = maximum(I), n = maximum(J))
V = __ones_like(x, length(I))
return sparse(I′, J′, V, m, n)
end

# Keywords processing
function __split_kwargs(; abstol, dt, adaptive = true, verbose = true, kwargs...)
return ((abstol, adaptive, verbose, dt), (; abstol, adaptive, verbose, kwargs...))
end
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorith
MaybeDiffCache, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, _sparse_like, get_dense_ad
__build_solution, __Fix3, _sparse_like, get_dense_ad,
__split_kwargs

using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ Generate new mesh based on the defect.
@views function mesh_selector!(cache::Union{
FIRKCacheExpand{iip, T}, FIRKCacheNested{iip, T}}) where {iip, T}
(; order, defect, mesh, mesh_dt) = cache
(abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
(abstol, _, _, _), kwargs = __split_kwargs(; cache.kwargs...)
N = length(mesh)

safety_factor = T(1.3)
Expand Down
56 changes: 29 additions & 27 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,18 @@ function shrink_y(y, N, stage)
end

function SciMLBase.__init(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,
abstol = 1e-3, adaptive = true, kwargs...)
abstol = 1e-3, adaptive = true, verbose = true, kwargs...)
if alg.nested_nlsolve
return init_nested(
prob, alg; dt = dt, abstol = abstol, adaptive = adaptive, kwargs...)
return init_nested(prob, alg; dt = dt, abstol = abstol,
adaptive = adaptive, verbose = verbose, kwargs...)
else
return init_expanded(
prob, alg; dt = dt, abstol = abstol, adaptive = adaptive, kwargs...)
return init_expanded(prob, alg; dt = dt, abstol = abstol,
adaptive = adaptive, verbose = verbose, kwargs...)
end
end

function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,
abstol = 1e-3, adaptive = true, kwargs...)
abstol = 1e-3, adaptive = true, verbose = true, kwargs...)
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)

iip = isinplace(prob)
Expand Down Expand Up @@ -177,13 +177,13 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,

return FIRKCacheNested{iip, T}(
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
k_discrete, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, nestprob,
nest_tol, resid₁_size, (; abstol, dt, adaptive, kwargs...))
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete,
y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, nestprob, nest_tol,
resid₁_size, (; abstol, dt, adaptive, verbose, kwargs...))
end

function init_expanded(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,
abstol = 1e-3, adaptive = true, kwargs...)
abstol = 1e-3, adaptive = true, verbose = true, kwargs...)
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)

if adaptive && isa(alg, FIRKNoAdaptivity)
Expand Down Expand Up @@ -258,9 +258,9 @@ function init_expanded(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob

return FIRKCacheExpand{iip, T}(
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p,
alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, y, y₀, residual,
fᵢ_cache, fᵢ₂_cache, defect, resid₁_size, (; abstol, dt, adaptive, kwargs...))
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, alg,
TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, y, y₀, residual, fᵢ_cache,
fᵢ₂_cache, defect, resid₁_size, (; abstol, dt, adaptive, verbose, kwargs...))
end

"""
Expand Down Expand Up @@ -289,23 +289,19 @@ function __expand_cache!(cache::FIRKCacheNested)
return cache
end

function __split_mirk_kwargs(; abstol, dt, adaptive = true, kwargs...)
return ((abstol, adaptive, dt), (; abstol, adaptive, kwargs...))
end

function SciMLBase.solve!(cache::FIRKCacheExpand)
(abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
(abstol, adaptive, verbose, _), kwargs = __split_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

# We do the first iteration outside the loop to preserve type-stability of the
# `original` field of the solution
sol_nlprob, info, defect_norm = __perform_firk_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)

if adaptive
while SciMLBase.successful_retcode(info) && defect_norm > abstol
sol_nlprob, info, defect_norm = __perform_firk_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)
end
end

Expand All @@ -320,18 +316,18 @@ function SciMLBase.solve!(cache::FIRKCacheExpand)
end

function SciMLBase.solve!(cache::FIRKCacheNested)
(abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
(abstol, adaptive, verbose, _), kwargs = __split_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

# We do the first iteration outside the loop to preserve type-stability of the
# `original` field of the solution
sol_nlprob, info, defect_norm = __perform_firk_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)

if adaptive
while SciMLBase.successful_retcode(info) && defect_norm > abstol
sol_nlprob, info, defect_norm = __perform_firk_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)
end
end

Expand All @@ -345,11 +341,11 @@ function SciMLBase.solve!(cache::FIRKCacheNested)
end

function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}, abstol,
adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
adaptive::Bool, verbose::Bool; nlsolve_kwargs = (;), kwargs...)
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
sol_nlprob = __solve(nlprob, nlsolve_alg; abstol = abstol, verbose = verbose,
kwargs..., nlsolve_kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)

defect_norm = 2 * abstol
Expand All @@ -362,11 +358,16 @@ function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}
if info == ReturnCode.Success # Nonlinear Solve was successful
defect_norm = defect_estimate!(cache)
# The defect is greater than 10%, the solution is not acceptable
defect_norm > cache.alg.defect_threshold && (info = ReturnCode.Failure)
if defect_norm > cache.alg.defect_threshold
verbose &&
@warn "Defect norm is $defect_norm, bigger than threshold $(cache.alg.defect_threshold), halving mesh"
info = ReturnCode.Failure
end
end

if info == ReturnCode.Success # Nonlinear Solve Successful and defect norm is acceptable
if defect_norm > abstol
verbose && @warn "Defect norm bigger than tolerance, refining mesh"
# We construct a new mesh to equidistribute the defect
mesh, mesh_dt, _, info = mesh_selector!(cache)
if info == ReturnCode.Success
Expand All @@ -381,6 +382,7 @@ function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}
# We cannot obtain a solution for the current mesh
if 2 * (length(cache.mesh) - 1) > cache.alg.max_num_subintervals
# New mesh would be too large
verbose && @warn "Mesh being too large and still failing to solve, exiting"
info = ReturnCode.Failure
else
half_mesh!(cache)
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorith
__extract_mesh, __extract_u0, __has_initial_guess,
__initial_guess_length, __initial_guess_on_mesh,
__flatten_initial_guess, __build_solution, __Fix3,
get_dense_ad
get_dense_ad, __split_kwargs

using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Generate new mesh based on the defect.
"""
@views function mesh_selector!(cache::MIRKCache{iip, T}) where {iip, T}
(; order, defect, mesh, mesh_dt) = cache
(abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
(abstol, _, _, _), kwargs = __split_kwargs(; cache.kwargs...)
N = length(mesh)

safety_factor = T(1.3)
Expand Down
34 changes: 18 additions & 16 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
Base.eltype(::MIRKCache{iip, T}) where {iip, T} = T

function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
abstol = 1e-3, adaptive = true, kwargs...)
abstol = 1e-3, adaptive = true, verbose = true, kwargs...)
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
iip = isinplace(prob)

Expand Down Expand Up @@ -104,9 +104,9 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,

return MIRKCache{iip, T}(
alg_order(alg), stage, N, size(X), f, bc, prob_, prob.problem_type,
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect,
new_stages, resid₁_size, (; abstol, dt, adaptive, kwargs...))
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete,
k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages,
resid₁_size, (; abstol, dt, adaptive, verbose, kwargs...))
end

"""
Expand All @@ -127,23 +127,19 @@ function __expand_cache!(cache::MIRKCache)
return cache
end

function __split_mirk_kwargs(; abstol, dt, adaptive = true, kwargs...)
return ((abstol, adaptive, dt), (; abstol, adaptive, kwargs...))
end

function SciMLBase.solve!(cache::MIRKCache)
(abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
(abstol, adaptive, verbose, _), kwargs = __split_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

# We do the first iteration outside the loop to preserve type-stability of the
# `original` field of the solution
sol_nlprob, info, defect_norm = __perform_mirk_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)

if adaptive
while SciMLBase.successful_retcode(info) && defect_norm > abstol
sol_nlprob, info, defect_norm = __perform_mirk_iteration(
cache, abstol, adaptive; kwargs...)
cache, abstol, adaptive, verbose; kwargs...)
end
end

Expand All @@ -156,12 +152,12 @@ function SciMLBase.solve!(cache::MIRKCache)
return __build_solution(cache.prob, odesol, sol_nlprob)
end

function __perform_mirk_iteration(
cache::MIRKCache, abstol, adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
function __perform_mirk_iteration(cache::MIRKCache, abstol, adaptive::Bool,
verbose::Bool; nlsolve_kwargs = (;), kwargs...)
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
sol_nlprob = __solve(nlprob, nlsolve_alg; abstol = abstol, verbose = verbose,
kwargs..., nlsolve_kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)

defect_norm = 2 * abstol
Expand All @@ -174,11 +170,16 @@ function __perform_mirk_iteration(
if info == ReturnCode.Success # Nonlinear Solve was successful
defect_norm = defect_estimate!(cache)
# The defect is greater than 10%, the solution is not acceptable
defect_norm > cache.alg.defect_threshold && (info = ReturnCode.Failure)
if defect_norm > cache.alg.defect_threshold
verbose &&
@warn "Defect norm is $defect_norm, bigger than threshold $(cache.alg.defect_threshold), halving mesh"
info = ReturnCode.Failure
end
end

if info == ReturnCode.Success # Nonlinear Solve Successful and defect norm is acceptable
if defect_norm > abstol
verbose && @warn "Defect norm bigger than tolerance, refining mesh"
# We construct a new mesh to equidistribute the defect
mesh, mesh_dt, _, info = mesh_selector!(cache)
if info == ReturnCode.Success
Expand All @@ -193,6 +194,7 @@ function __perform_mirk_iteration(
# We cannot obtain a solution for the current mesh
if 2 * (length(cache.mesh) - 1) > cache.alg.max_num_subintervals
# New mesh would be too large
verbose && @warn "Mesh being too large and still failing to solve, exiting"
info = ReturnCode.Failure
else
half_mesh!(cache)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorith
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __default_sparse_ad,
__default_nonsparse_ad, get_dense_ad,
concrete_jacobian_algorithm
concrete_jacobian_algorithm, __split_kwargs

using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
Expand Down
Loading
Loading