From 12cebe731e811f50796441896164a5d038047e97 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 03:04:53 -0400 Subject: [PATCH] feat: check for branching for ReverseDiff(compile=true) --- lib/NonlinearSolveBase/Project.toml | 2 ++ lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 10 ++++++---- lib/NonlinearSolveBase/src/autodiff.jl | 9 +++++++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 819bcc79f..3999de770 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -12,6 +12,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -36,6 +37,7 @@ DifferentiationInterface = "0.6.1" EnzymeCore = "0.8" FastClosures = "0.3" ForwardDiff = "0.10.36" +FunctionProperties = "0.1.2" LinearAlgebra = "1.10" Markdown = "1.10" RecursiveArrayTools = "3" diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 4b3eec258..5e1a37326 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -7,13 +7,14 @@ using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using EnzymeCore: EnzymeCore using FastClosures: @closure +using FunctionProperties: hasbranching using LinearAlgebra: norm using Markdown: @doc_str using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction, - @add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem, - isinplace, warn_paramtype + @add_kwonly, StandardNonlinearProblem, NullParameters, isinplace, + warn_paramtype using StaticArraysCore: StaticArray const DI = DifferentiationInterface @@ -30,8 +31,9 @@ include("autodiff.jl") # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) @compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution)) -@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff, - select_jacobian_autodiff)) +@compat(public, + (select_forward_mode_autodiff, select_reverse_mode_autodiff, + select_jacobian_autodiff)) export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode, diff --git a/lib/NonlinearSolveBase/src/autodiff.jl b/lib/NonlinearSolveBase/src/autodiff.jl index d2e51389d..f81ce7039 100644 --- a/lib/NonlinearSolveBase/src/autodiff.jl +++ b/lib/NonlinearSolveBase/src/autodiff.jl @@ -7,6 +7,7 @@ const ReverseADs = [ ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse), ADTypes.AutoZygote(), ADTypes.AutoTracker(), + ADTypes.AutoReverseDiff(; compile = true), ADTypes.AutoReverseDiff(), ADTypes.AutoFiniteDiff() ] @@ -103,6 +104,14 @@ function incompatible_backend_and_problem( end additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false +function additional_incompatible_backend_check(prob::AbstractNonlinearProblem, + ::ADTypes.AutoReverseDiff{true}) + if SciMLBase.isinplace(prob) + fu = prob.f.resid_prototype === nothing ? zero(prob.u0) : prob.f.resid_prototype + return hasbranching(prob.f, fu, prob.u0, prob.p) + end + return hasbranching(prob.f, prob.u0, prob.p) +end is_finite_differences_backend(ad::AbstractADType) = false is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true