Skip to content

Commit

Permalink
feat: automatic backend selection for autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 26, 2024
1 parent ab4c9f8 commit 77cd271
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 1 deletion.
6 changes: 6 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand All @@ -24,10 +27,13 @@ NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"

[compat]
ADTypes = "1.9"
ArrayInterface = "7.9"
CommonSolve = "0.2.4"
Compat = "4.15"
ConcreteStructs = "0.2.3"
DifferentiationInterface = "0.6.1"
EnzymeCore = "0.8"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using CommonSolve: solve
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem,
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem,
NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils

function NonlinearSolveBase.additional_incompatible_backend_check(
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
return !ForwardDiff.can_dual(eltype(prob.u0))
end

Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)
Expand Down
9 changes: 9 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module NonlinearSolveBase

using ADTypes: ADTypes, AbstractADType, ForwardMode, ReverseMode
using ArrayInterface: ArrayInterface
using Compat: @compat
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using LinearAlgebra: norm
using Markdown: @doc_str
Expand All @@ -13,16 +16,22 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear
isinplace, warn_paramtype
using StaticArraysCore: StaticArray

const DI = DifferentiationInterface

include("public.jl")
include("utils.jl")

include("immutable_problem.jl")
include("common_defaults.jl")
include("termination_conditions.jl")

include("autodiff.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff,
select_jacobian_autodiff))

export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
Expand Down
109 changes: 109 additions & 0 deletions lib/NonlinearSolveBase/src/autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Here we determine the preferred AD backend. We have a predefined list of ADs and then
# we select the first one that is avialable and would work with the problem.

# Ordering is important here. We want to select the first one that is compatible with the
# problem.
const ReverseADs = [
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
ADTypes.AutoZygote(),
ADTypes.AutoTracker(),
ADTypes.AutoReverseDiff(),
ADTypes.AutoFiniteDiff()
]

const ForwardADs = [
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
ADTypes.AutoPolyesterForwardDiff(),
ADTypes.AutoForwardDiff(),
ADTypes.AutoFiniteDiff()
]

# TODO: Handle Sparsity

function select_forward_mode_autodiff(
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode)
@warn "The chosen AD backend $(ad) is not a forward mode AD. Use with caution."
end
if incompatible_backend_and_problem(prob, ad)
adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode)
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
running autodiff selection detected `$(adₙ)` as a potential forward mode \
backend."
return adₙ
end
return ad
end

function select_forward_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing;
warn_check_mode::Bool = true)
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs)
idx !== nothing && return ForwardADs[idx]
throw(ArgumentError("No forward mode AD backend is compatible with the chosen problem. \
This could be because no forward mode autodiff backend is loaded \
or the loaded backends don't support the problem."))
end

function select_reverse_mode_autodiff(
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ReverseMode)
if !is_finite_differences_backend(ad)
@warn "The chosen AD backend $(ad) is not a reverse mode AD. Use with caution."
else
@warn "The chosen AD backend $(ad) is a finite differences backend. This might \
be slow and inaccurate. Use with caution."
end
end
if incompatible_backend_and_problem(prob, ad)
adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode)
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
running autodiff selection detected `$(adₙ)` as a potential reverse mode \
backend."
return adₙ
end
return ad
end

function select_reverse_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing;
warn_check_mode::Bool = true)
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs)
idx !== nothing && return ReverseADs[idx]
throw(ArgumentError("No reverse mode AD backend is compatible with the chosen problem. \
This could be because no reverse mode autodiff backend is loaded \
or the loaded backends don't support the problem."))
end

function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType)
if incompatible_backend_and_problem(prob, ad)
adₙ = select_jacobian_autodiff(prob, nothing)
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
running autodiff selection detected `$(adₙ)` as a potential jacobian \
backend."
return adₙ
end
return ad
end

function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ::Nothing)
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs)
idx !== nothing && !is_finite_differences_backend(ForwardADs[idx]) &&
return ForwardADs[idx]
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs)
idx !== nothing && return ReverseADs[idx]
throw(ArgumentError("No jacobian AD backend is compatible with the chosen problem. \
This could be because no jacobian autodiff backend is loaded \
or the loaded backends don't support the problem."))
end

function incompatible_backend_and_problem(
prob::AbstractNonlinearProblem, ad::AbstractADType)
!DI.check_available(ad) && return true
SciMLBase.isinplace(prob) && !DI.check_inplace(ad) && return true
return additional_incompatible_backend_check(prob, ad)
end

additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false

is_finite_differences_backend(ad::AbstractADType) = false
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true

0 comments on commit 77cd271

Please sign in to comment.