diff --git a/Project.toml b/Project.toml index 6592da9..5a1497d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Roots" uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" -version = "2.1.8" +version = "2.2.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -9,12 +9,14 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807" SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c" [extensions] +RootsChainRulesCoreExt = "ChainRulesCore" RootsForwardDiffExt = "ForwardDiff" RootsIntervalRootFindingExt = "IntervalRootFinding" RootsSymPyExt = "SymPy" diff --git a/src/chain_rules.jl b/ext/RootsChainRulesCoreExt.jl similarity index 91% rename from src/chain_rules.jl rename to ext/RootsChainRulesCoreExt.jl index 9c1b374..c45dabb 100644 --- a/src/chain_rules.jl +++ b/ext/RootsChainRulesCoreExt.jl @@ -1,3 +1,8 @@ +module RootsChainRulesCoreExt + +using Roots +import ChainRulesCore + # View find_zero as solving `f(x, p) = 0` for `xᵅ(p)`. # This is implicitly defined. By the implicit function theorem, we have: # ∇f = 0 => ∂/∂ₓ f(xᵅ, p) ⋅ ∂xᵅ/∂ₚ + ∂/∂ₚf(x\^α, p) ⋅ I = 0 @@ -15,7 +20,6 @@ # that is fixable.) # this assumes a function and a parameter `p` passed in -import ChainRulesCore: Tangent, NoTangent, frule, rrule function ChainRulesCore.frule( config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode}, (_, _, _, Δp), @@ -42,17 +46,17 @@ ChainRulesCore.frule( config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode}, xdots, ::typeof(solve), - ZP::Roots.ZeroProblem, + ZP::ZeroProblem, M::Roots.AbstractUnivariateZeroMethod, ::Nothing; kwargs..., -) = frule(config, xdots, solve, ZP, M; kwargs...) +) = ChainRulesCore.frule(config, xdots, solve, ZP, M; kwargs...) function ChainRulesCore.frule( config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode}, (_, Δq, _), ::typeof(solve), - ZP::Roots.ZeroProblem, + ZP::ZeroProblem, M::Roots.AbstractUnivariateZeroMethod; kwargs..., ) @@ -61,12 +65,12 @@ function ChainRulesCore.frule( zprob2 = ZeroProblem(|>, ZP.x₀) nms = fieldnames(typeof(foo)) nt = NamedTuple{nms}(getfield(foo, n) for n in nms) - dfoo = Tangent{typeof(foo)}(; nt...) + dfoo = ChainRulesCore.Tangent{typeof(foo)}(; nt...) - return frule( + return ChainRulesCore.frule( config, - (NoTangent(), NoTangent(), NoTangent(), dfoo), - Roots.solve, + (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dfoo), + solve, zprob2, M, foo, @@ -146,3 +150,5 @@ function ChainRulesCore.rrule( return xᵅ, pullback_solve_ZeroProblem end + +end # module diff --git a/src/Roots.jl b/src/Roots.jl index 35ddf0d..8dea221 100644 --- a/src/Roots.jl +++ b/src/Roots.jl @@ -22,7 +22,6 @@ using Printf import CommonSolve import CommonSolve: solve, solve!, init using Accessors -import ChainRulesCore export fzero, fzeros, secant_method @@ -53,7 +52,6 @@ include("functions.jl") include("trace.jl") include("find_zero.jl") include("hybrid.jl") -include("chain_rules.jl") include("Bracketing/bracketing.jl") include("Bracketing/bisection.jl") @@ -83,4 +81,8 @@ include("find_zeros.jl") include("simple.jl") include("alternative_interfaces.jl") +if !isdefined(Base, :get_extension) + include("../ext/RootsChainRulesCoreExt.jl") +end + end