Skip to content

Commit

Permalink
Update ODE retcode handling
Browse files Browse the repository at this point in the history
  • Loading branch information
gerlero committed Dec 29, 2023
1 parent 3a939b8 commit 77c3abc
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 47 deletions.
1 change: 0 additions & 1 deletion src/Fronts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ using PCHIPInterpolation: Interpolator, integrate
import NumericalIntegration
using RecipesBase

using OrdinaryDiffEq.SciMLBase: NullParameters
using OrdinaryDiffEq: ODEFunction, ODEProblem, ODESolution
using OrdinaryDiffEq: init, solve!, reinit!
using OrdinaryDiffEq: DiscreteCallback, terminate!
Expand Down
2 changes: 1 addition & 1 deletion src/ParamEstim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module ParamEstim
import ..Fronts
using ..Fronts: InverseProblem, AbstractSemiinfiniteProblem, Solution, ReturnCode, solve
import ..Fronts: sorptivity
import ..Fronts.SciMLBase: successful_retcode, NullParameters

using LsqFit: curve_fit
import OrdinaryDiffEq.SciMLBase: successful_retcode, NullParameters

"""
ScaledSolution
Expand Down
22 changes: 12 additions & 10 deletions src/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
Transform `prob` into an ODE problem in terms of the Boltzmann variable `o`.
The ODE problem is set up to terminate automatically (`ReturnCode.Terminated`) when the steady state is reached.
The ODE problem is set up to terminate automatically (with `.retcode == ReturnCode.Success`) when the steady state is reached.
See also: [`DifferentialEquations`](https://diffeq.sciml.ai/stable/)
"""
Expand All @@ -45,7 +45,11 @@ function boltzmann(prob::Union{CauchyProblem, SorptivityCauchyProblem})
settled = DiscreteCallback(let direction = monotonicity(prob)
(u, t, integrator) -> direction * u[2] zero(u[2])
end,
terminate!,
function succeed!(integrator)
terminate!(integrator)
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.Success)
end,
save_positions = (false, false))

ODEProblem(boltzmann(prob.eq), u0, (ob, typemax(ob)), callback = settled)
Expand Down Expand Up @@ -112,16 +116,14 @@ function solve(prob::Union{CauchyProblem, SorptivityCauchyProblem},
verbose = true)
odesol = solve!(_init(prob, alg, verbose = verbose))

@assert odesol.retcode != ReturnCode.Success

if odesol.retcode != ReturnCode.Terminated
return Solution(odesol, prob, alg, _retcode = odesol.retcode, _niter = 1)
end

return Solution(odesol, prob, alg, _retcode = ReturnCode.Success, _niter = 1)
return Solution(odesol, prob, alg, _niter = 1)
end

function Solution(_odesol::ODESolution, _prob, _alg::BoltzmannODE; _retcode, _niter)
function Solution(_odesol::ODESolution,
_prob,
_alg::BoltzmannODE;
_retcode = _odesol.retcode,
_niter)
return Solution(o -> _odesol(o, idxs = 1),
_prob,
_alg,
Expand Down
8 changes: 4 additions & 4 deletions src/odes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ See also: [`DifferentialEquations`](https://diffeq.sciml.ai/stable/), [`StaticAr
"""
function boltzmann(eq::DiffusionEquation{1})
let K = u -> conductivity(eq, u), C = u -> capacity(eq, u)
function f((u, du_do), ::NullParameters, o)
function f((u, du_do), ::SciMLBase.NullParameters, o)
K_, dK_du = value_and_derivative(K, u)

d²u_do² = -((C(u) * o / 2 + dK_du * du_do) / K_) * du_do

return @SVector [du_do, d²u_do²]
end
function jac((u, du_do), ::NullParameters, o)
function jac((u, du_do), ::SciMLBase.NullParameters, o)
K_, dK_du, d²K_du² = value_and_derivatives(K, u)
C_, dC_du = value_and_derivative(C, u)

Expand All @@ -36,14 +36,14 @@ end
function boltzmann(eq::DiffusionEquation{m}) where {m}
@assert m in 2:3
let K = u -> conductivity(eq, u), C = u -> capacity(eq, u), k = m - 1
function f((u, du_do), ::NullParameters, o)
function f((u, du_do), ::SciMLBase.NullParameters, o)
K_, dK_du = value_and_derivative(K, u)

d²u_do² = -((C(u) * o / 2 + dK_du * du_do) / K_ + k / o) * du_do

return @SVector [du_do, d²u_do²]
end
function jac((u, du_do), ::NullParameters, o)
function jac((u, du_do), ::SciMLBase.NullParameters, o)
K_, dK_du, d²K_du² = value_and_derivatives(K, u)
C_, dC_du = value_and_derivative(C, u)

Expand Down
38 changes: 8 additions & 30 deletions src/shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,11 @@ function solve(prob::DirichletProblem, alg::BoltzmannODE = BoltzmannODE();
CauchyProblem(prob.eq, b = prob.b, d_dob = zero(d_dob_hint), ob = prob.ob))
solve!(integrator)

@assert integrator.sol.retcode != ReturnCode.Success
retcode = integrator.sol.retcode == ReturnCode.Terminated ? ReturnCode.Success :
integrator.sol.retcode
if verbose && !SciMLBase.successful_retcode(integrator.sol)
if verbose && integrator.sol.retcode != ReturnCode.Success
@warn "Problem has a trivial solution but failed to obtain it"
end

return Solution(integrator.sol, prob, alg, _retcode = retcode, _niter = 0)
return Solution(integrator.sol, prob, alg, _niter = 0)
end

d_dob_trial = bracket_bisect(zero(d_dob_hint), d_dob_hint, resid)
Expand All @@ -68,20 +65,14 @@ function solve(prob::DirichletProblem, alg::BoltzmannODE = BoltzmannODE();
CauchyProblem(prob.eq, b = prob.b, d_dob = d_dob_trial(resid), ob = prob.ob))
solve!(integrator)

@assert integrator.sol.retcode != ReturnCode.Success
if integrator.sol.retcode == ReturnCode.Terminated &&
direction * integrator.sol.u[end][1] direction * limit
if integrator.sol.retcode == ReturnCode.Success
resid = integrator.sol.u[end][1] - prob.i
else
resid = direction * typemax(prob.i)
end

if abs(resid) abstol
return Solution(integrator.sol,
prob,
alg,
_retcode = ReturnCode.Success,
_niter = niter)
return Solution(integrator.sol, prob, alg, _niter = niter)
end
end

Expand Down Expand Up @@ -150,18 +141,11 @@ function solve(prob::Union{FlowrateProblem, SorptivityProblem},
SorptivityCauchyProblem(prob.eq, b = prob.i, S = zero(S), ob = ob))
solve!(integrator)

@assert integrator.sol.retcode != ReturnCode.Success
retcode = integrator.sol.retcode == ReturnCode.Terminated ? ReturnCode.Success :
integrator.sol.retcode
if verbose && !SciMLBase.successful_retcode(integrator.sol)
if verbose && integrator.sol.retcode != ReturnCode.Success
@warn "Problem has a trivial solution but failed to obtain it"
end

return Solution(integrator.sol,
prob,
alg,
_retcode = retcode,
_niter = 0)
return Solution(integrator.sol, prob, alg, _niter = 0)
end

b_trial = bracket_bisect(prob.i, b_hint)
Expand All @@ -171,9 +155,7 @@ function solve(prob::Union{FlowrateProblem, SorptivityProblem},
SorptivityCauchyProblem(prob.eq, b = b_trial(resid), S = S, ob = ob))
solve!(integrator)

@assert integrator.sol.retcode != ReturnCode.Success
if integrator.sol.retcode == ReturnCode.Terminated &&
direction * integrator.sol.u[end][1] direction * limit
if integrator.sol.retcode == ReturnCode.Success
resid = integrator.sol.u[end][1] - prob.i
elseif integrator.sol.retcode != ReturnCode.Terminated && integrator.t == ob
resid = -direction * typemax(prob.i)
Expand All @@ -182,11 +164,7 @@ function solve(prob::Union{FlowrateProblem, SorptivityProblem},
end

if abs(resid) abstol
return Solution(integrator.sol,
prob,
alg,
_retcode = ReturnCode.Success,
_niter = niter)
return Solution(integrator.sol, prob, alg, _niter = niter)
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ using Fronts
using Fronts._Diff
using Fronts.PorousModels
using Fronts.ParamEstim
using Fronts.SciMLBase: NullParameters
using Test

import ForwardDiff
import NaNMath
using NumericalIntegration
using OrdinaryDiffEq: ODEFunction, ODEProblem
using OrdinaryDiffEq.DiffEqBase: NullParameters
using StaticArrays: @SVector, SVector

using Plots: plot
Expand Down

0 comments on commit 77c3abc

Please sign in to comment.