diff --git a/Project.toml b/Project.toml index 33e3620..881ca66 100644 --- a/Project.toml +++ b/Project.toml @@ -26,12 +26,14 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c" [extensions] SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" SimpleNonlinearSolveTrackerExt = "Tracker" SimpleNonlinearSolveZygoteExt = "Zygote" +SimpleNonlinearSolveTaylorDiffExt = "TaylorDiff" [compat] ADTypes = "1.9" diff --git a/ext/SimpleNonlinearSolveTaylorDiffExt.jl b/ext/SimpleNonlinearSolveTaylorDiffExt.jl new file mode 100644 index 0000000..3257c68 --- /dev/null +++ b/ext/SimpleNonlinearSolveTaylorDiffExt.jl @@ -0,0 +1,60 @@ +module SimpleNonlinearSolveTaylorDiffExt +using SimpleNonlinearSolve +using SimpleNonlinearSolve: ImmutableNonlinearProblem, ReturnCode, build_solution, check_termination, init_termination_cache +using SimpleNonlinearSolve: __maybe_unaliased, _get_fx, __fixed_parameter_function +using MaybeInplace: @bb +using SciMLBase: isinplace + +import TaylorDiff + +@inline function __get_higher_order_derivatives(::SimpleHouseholder{N}, prob, f, x, fx) where N + vN = Val(N) + l = map(one, x) + t = TaylorDiff.make_seed(x, l, vN) + + if isinplace(prob) + bundle = similar(fx, TaylorDiff.TaylorScalar{eltype(fx), N}) + f(bundle, t) + map!(TaylorDiff.primal, fx, bundle) + else + bundle = f(t) + fx = map(TaylorDiff.primal, bundle) + end + bundle = inv.(bundle) + num = TaylorDiff.extract_derivative(bundle, N - 1) + den = TaylorDiff.extract_derivative(bundle, N) + return num, den, fx +end + +function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHouseholder{N}, + args...; abstol = nothing, reltol = nothing, maxiters = 1000, + termination_condition = nothing, alias_u0 = false, kwargs...) where N + x = __maybe_unaliased(prob.u0, alias_u0) + length(x) == 1 || + throw(ArgumentError("SimpleHouseholder only supports scalar problems")) + fx = _get_fx(prob, x) + @bb xo = copy(x) + f = __fixed_parameter_function(prob) + + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) + + for i in 1:maxiters + num, den, fx = __get_higher_order_derivatives(alg, prob, f, x, fx) + + if i == 1 + iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) + else + # Termination Checks + tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg) + tc_sol !== nothing && return tc_sol + end + + @bb copyto!(xo, x) + @bb x .+= (N - 1) .* num ./ den + end + + return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) +end + +end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index ba0b243..e6c939b 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -47,6 +47,7 @@ include("nlsolve/lbroyden.jl") include("nlsolve/klement.jl") include("nlsolve/trustRegion.jl") include("nlsolve/halley.jl") +include("nlsolve/householder.jl") include("nlsolve/dfsane.jl") ## Interval Nonlinear Solvers @@ -139,6 +140,7 @@ end export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff export SimpleBroyden, SimpleDFSane, SimpleGaussNewton, SimpleHalley, SimpleKlement, SimpleLimitedMemoryBroyden, SimpleNewtonRaphson, SimpleTrustRegion +export SimpleHouseholder export Alefeld, Bisection, Brent, Falsi, ITP, Ridder end # module diff --git a/src/nlsolve/householder.jl b/src/nlsolve/householder.jl new file mode 100644 index 0000000..ce06c5f --- /dev/null +++ b/src/nlsolve/householder.jl @@ -0,0 +1,16 @@ +""" + SimpleHouseholder{order}() + +A low-overhead implementation of Householder's method to arbitrary order. +This method is non-allocating on scalar and static array problems. + +!!! warning + + Needs `TaylorDiff.jl` to be explicitly loaded before using this functionality. + Internally, this uses TaylorDiff.jl for automatic differentiation. + +### Type Parameters + + - `order`: the convergence order of the Householder method. `order = 2` is the same as Newton's method, `order = 3` is the same as Halley's method, etc. +""" +struct SimpleHouseholder{order} <: AbstractNewtonAlgorithm end diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 272fa3c..76ba85d 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -91,6 +91,31 @@ end end end +@testitem "SimpleHouseholder" setup=[RootfindingTesting] tags=[:core] begin + using TaylorDiff + @testset "AutoDiff: TaylorDiff.jl" for order in (2, 3, 4) + @testset "[OOP] u0: $(nameof(typeof(u0)))" for u0 in ( + [1.0], @SVector[1.0], 1.0) + sol = benchmark_nlsolve_oop(quadratic_f, u0; solver = SimpleHouseholder{order}()) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + end + + @testset "[IIP] u0: $(nameof(typeof(u0)))" for u0 in ([1.0],) + sol = benchmark_nlsolve_iip(quadratic_f!, u0; solver = SimpleHouseholder{order}()) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + end + end + + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + u0 in (1.0, [1.0], @SVector[1.0]) + + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, SimpleHouseholder{2}(); termination_condition).u .≈ sqrt(2.0)) + end +end + @testitem "Derivative Free Metods" setup=[RootfindingTesting] tags=[:core] begin @testset "$(nameof(typeof(alg)))" for alg in [ SimpleBroyden(), SimpleKlement(), SimpleDFSane(),