diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 3bab625c7..4430c2ae2 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -127,6 +127,7 @@ include("stats.jl") include("calculate_residuals.jl") include("tableaus.jl") include("internal_falsi.jl") +include("internal_itp.jl") include("callbacks.jl") include("common_defaults.jl") diff --git a/src/callbacks.jl b/src/callbacks.jl index 5d3d886e9..e5a09045c 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -358,15 +358,16 @@ end # rough implementation, needs multiple type handling # always ensures that if r = bisection(f, (x0, x1)) # then either f(nextfloat(r)) == 0 or f(nextfloat(r)) * f(r) < 0 +# note: not really using bisection - uses the ITP method function bisection(f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol; maxiters = 1000) if rootfind == SciMLBase.LeftRootFind solve(IntervalNonlinearProblem{false}(f, tup), - InternalFalsi(), abstol = abstol, + InternalITP(), abstol = abstol, reltol = reltol).left else solve(IntervalNonlinearProblem{false}(f, tup), - InternalFalsi(), abstol = abstol, + InternalITP(), abstol = abstol, reltol = reltol).right end end diff --git a/src/internal_falsi.jl b/src/internal_falsi.jl index 01b3013d4..a6f729a4d 100644 --- a/src/internal_falsi.jl +++ b/src/internal_falsi.jl @@ -36,55 +36,117 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::InternalFalsi, arg i = 1 if !iszero(fr) + using_falsi_steps = true while i < maxiters - if nextfloat_tdir(left, prob.tspan...) == right + # First, perform a regula falsi iteration + if using_falsi_steps + if nextfloat_tdir(left, prob.tspan...) == right + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = ReturnCode.FloatingPointLimit, + left = left, right = right) + end + mid = (fr * left - fl * right) / (fr - fl) + for i in 1:10 + mid = max_tdir(left, prevfloat_tdir(mid, prob.tspan...), prob.tspan...) + end + if mid == right || mid == left + using_falsi_steps = false + continue + end + fm = f(mid) + if iszero(fm) + right = mid + using_falsi_steps = false + continue + end + if sign(fl) == sign(fm) + fl = fm + left = mid + else + fr = fm + right = mid + end + i += 1 + end + + # Then, perform a bisection iteration + mid = (left + right) / 2 + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.FloatingPointLimit, left = left, right = right) - end - mid = (fr * left - fl * right) / (fr - fl) - for i in 1:10 - mid = max_tdir(left, prevfloat_tdir(mid, prob.tspan...), prob.tspan...) - end - if mid == right || mid == left - break - end fm = f(mid) if iszero(fm) right = mid - break - end - if sign(fl) == sign(fm) - fl = fm + fr = fm + elseif sign(fm) == sign(fl) left = mid + fl = fm else - fr = fm right = mid + fr = fm end i += 1 end end - while i < maxiters - mid = (left + right) / 2 - (mid == left || mid == right) && - return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.FloatingPointLimit, - left = left, right = right) - fm = f(mid) - if iszero(fm) - right = mid - fr = fm - elseif sign(fm) == sign(fl) - left = mid - fl = fm - else - right = mid - fr = fm - end - i += 1 - end - return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, left = left, right = right) end + +function scalar_nlsolve_ad(prob, alg::InternalFalsi, args...; kwargs...) + f = prob.f + p = value(prob.p) + + if prob isa IntervalNonlinearProblem + tspan = value(prob.tspan) + newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...) + else + u0 = value(prob.u0) + newprob = NonlinearProblem(f, u0, p; prob.kwargs...) + end + + sol = solve(newprob, alg, args...; kwargs...) + + uu = sol.u + if p isa Number + f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) + else + f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p) + end + + f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu) + pp = prob.p + sumfun = let f_x′ = -f_x + ((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p) + end + partials = sum(sumfun, zip(f_p, pp)) + return sol, partials +end + +function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, + <:ForwardDiff.Dual{T, V, P}}, + alg::InternalFalsi, args...; + kwargs...) where {uType, iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), + sol.resid; retcode = sol.retcode, + left = ForwardDiff.Dual{T, V, P}(sol.left, partials), + right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) +end + +function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, + <:AbstractArray{ + <:ForwardDiff.Dual{T, + V, + P}, + }}, + alg::InternalFalsi, args...; + kwargs...) where {uType, iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + + return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), + sol.resid; retcode = sol.retcode, + left = ForwardDiff.Dual{T, V, P}(sol.left, partials), + right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) +end diff --git a/src/internal_itp.jl b/src/internal_itp.jl new file mode 100644 index 000000000..e778793c3 --- /dev/null +++ b/src/internal_itp.jl @@ -0,0 +1,155 @@ +""" +`InternalITP`: A non-allocating ITP method, internal to DiffEqBase for +simpler dependencies. +""" +struct InternalITP + k1::Float64 + k2::Float64 + n0::Int +end + +InternalITP() = InternalITP(0.007, 1.5, 10) + +function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::InternalITP, args...; + maxiters = 1000, kwargs...) where {IP, T} + f = Base.Fix2(prob.f, prob.p) + left, right = prob.tspan # a and b + fl, fr = f(left), f(right) + ϵ = eps(T) + if iszero(fl) + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = ReturnCode.ExactSolutionLeft, left = left, + right = right) + elseif iszero(fr) + return SciMLBase.build_solution(prob, alg, right, fr; + retcode = ReturnCode.ExactSolutionRight, left = left, + right = right) + end + #defining variables/cache + k1 = alg.k1 + k2 = alg.k2 + n0 = alg.n0 + n_h = ceil(log2((right - left) / (2 * ϵ))) + mid = (left + right) / 2 + x_f = (fr * left - fl * right) / (fr - fl) + xt = left + xp = left + r = zero(left) #minmax radius + δ = zero(left) # truncation error + σ = 1.0 + ϵ_s = ϵ * 2^(n_h + n0) + i = 0 #iteration + while i <= maxiters + #mid = (left + right) / 2 + span = abs(right - left) + r = ϵ_s - (span / 2) + δ = k1 * (span^k2) + + ## Interpolation step ## + x_f = left + (right - left) * (fl/(fl - fr)) + + ## Truncation step ## + σ = sign(mid - x_f) + if δ <= abs(mid - x_f) + xt = x_f + (σ * δ) + else + xt = mid + end + + ## Projection step ## + if abs(xt - mid) <= r + xp = xt + else + xp = mid - (σ * r) + end + + ## Update ## + tmax = max(left, right) + tmin = min(left, right) + xp >= tmax && (xp = prevfloat(tmax)) + xp <= tmin && (xp = nextfloat(tmin)) + yp = f(xp) + yps = yp * sign(fr) + if yps > 0 + right = xp + fr = yp + elseif yps < 0 + left = xp + fl = yp + else + left = prevfloat_tdir(xp, prob.tspan...) + right = xp + return SciMLBase.build_solution(prob, alg, left, f(left); + retcode = ReturnCode.Success, left = left, + right = right) + end + i += 1 + mid = (left + right) / 2 + ϵ_s /= 2 + + if nextfloat_tdir(left, prob.tspan...) == right + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = ReturnCode.FloatingPointLimit, left = left, + right = right) + end + end + return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, + left = left, right = right) +end + +function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...) + f = prob.f + p = value(prob.p) + + if prob isa IntervalNonlinearProblem + tspan = value(prob.tspan) + newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...) + else + u0 = value(prob.u0) + newprob = NonlinearProblem(f, u0, p; prob.kwargs...) + end + + sol = solve(newprob, alg, args...; kwargs...) + + uu = sol.u + if p isa Number + f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) + else + f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p) + end + + f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu) + pp = prob.p + sumfun = let f_x′ = -f_x + ((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p) + end + partials = sum(sumfun, zip(f_p, pp)) + return sol, partials +end + +function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, + <:ForwardDiff.Dual{T, V, P}}, + alg::InternalITP, args...; + kwargs...) where {uType, iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), + sol.resid; retcode = sol.retcode, + left = ForwardDiff.Dual{T, V, P}(sol.left, partials), + right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) +end + +function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, + <:AbstractArray{ + <:ForwardDiff.Dual{T, + V, + P}, + }}, + alg::InternalITP, args...; + kwargs...) where {uType, iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + + return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), + sol.resid; retcode = sol.retcode, + left = ForwardDiff.Dual{T, V, P}(sol.left, partials), + right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) +end diff --git a/test/internal_rootfinder.jl b/test/internal_rootfinder.jl new file mode 100644 index 000000000..5458b17c3 --- /dev/null +++ b/test/internal_rootfinder.jl @@ -0,0 +1,24 @@ +using DiffEqBase +using DiffEqBase: InternalFalsi, InternalITP, IntervalNonlinearProblem +using ForwardDiff + +for Rootfinder in (InternalFalsi, InternalITP) + # From SimpleNonlinearSolve + f = (u, p) -> u * u - p + tspan = (1.0, 20.0) + g = function (p) + probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p) + sol = solve(probN, Rootfinder()) + return sol.u + end + + for p in (1.0,) #1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + #@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) + end + + # https://github.com/SciML/DiffEqBase.jl/issues/916 + inp = IntervalNonlinearProblem((t, p) -> min(-1.0 + 0.001427344607477125 * t, 1e-9), + (699.0079267259368, 700.6176418816023)) + @test solve(inp, Rootfinder()).u ≈ 700.6016590257979 +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index fee17f03e..7a92bc7f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,9 @@ end @time @safetestset "Callbacks" begin include("callbacks.jl") end + @time @safetestset "Internal Rootfinders" begin + include("internal_rootfinder.jl") + end @time @safetestset "Plot Vars" begin include("plot_vars.jl") end