Skip to content

Commit

Permalink
Merge pull request #203 from avik-pal/ap/cleanup
Browse files Browse the repository at this point in the history
Towards a cleaner and more maintainable internals of NonlinearSolve.jl
  • Loading branch information
ChrisRackauckas authored Sep 21, 2023
2 parents 81e9164 + 4cd2d97 commit b2946b1
Show file tree
Hide file tree
Showing 18 changed files with 1,081 additions and 2,124 deletions.
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
style = "sciml"
format_markdown = true
format_markdown = true
annotate_untyped_fields_with_any = false
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
julia-version: [1,1.6]
julia-version: [1]
os: [ubuntu-latest]
package:
- {user: SciML, repo: ModelingToolkit.jl, group: All}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ Manifest.toml
docs/src/assets/Project.toml

.vscode
wip
23 changes: 17 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "1.10.0"
version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand All @@ -22,33 +25,41 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
DiffEqBase = "6"
ConcreteStructs = "0.2"
DiffEqBase = "6.130"
EnumX = "1"
Enzyme = "0.11"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LinearSolve = "2"
LineSearches = "7"
PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "1.92.4"
SciMLBase = "1.97"
SimpleNonlinearSolve = "0.1"
SparseDiffTools = "1, 2"
SparseDiffTools = "2.6"
StaticArraysCore = "1.4"
UnPack = "1.0"
julia = "1.6"
Zygote = "0.6"
julia = "1.9"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools"]
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
BenchmarkTools = "1"
Documenter = "0.27"
LinearSolve = "2"
NonlinearSolve = "1"
NonlinearSolve = "1, 2"
NonlinearSolveMINPACK = "0.1"
SciMLNLSolve = "0.1"
SimpleNonlinearSolve = "0.1.5"
Expand Down
64 changes: 34 additions & 30 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
module NonlinearSolve
if isdefined(Base, :Experimental) &&
isdefined(Base.Experimental, Symbol("@max_methods"))

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods"))
@eval Base.Experimental.@max_methods 1
end
using Reexport
using UnPack: @unpack
using FiniteDiff, ForwardDiff
using ForwardDiff: Dual
using LinearAlgebra
using StaticArraysCore
using RecursiveArrayTools
import EnumX
import ArrayInterface
import LinearSolve
using DiffEqBase
using SparseDiffTools

@reexport using SciMLBase
using SciMLBase: NLStats
@reexport using SimpleNonlinearSolve

import SciMLBase: _unwrap_val

abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
AbstractNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...)

using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
import ForwardDiff

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
import ConcreteStructs: @concrete
import EnumX: @enumx
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import Reexport: @reexport
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import UnPack: @unpack

@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve

const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
sol = solve!(cache)
return solve!(cache)
end

include("utils.jl")
include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
include("levenberg.jl")
Expand All @@ -46,7 +48,7 @@ PrecompileTools.@compile_workload begin
for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = if VERSION >= v"1.7"
precompile_algs = if VERSION v"1.7"
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
else
(NewtonRaphson(),)
Expand All @@ -68,4 +70,6 @@ export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt

export LineSearch

end # module
67 changes: 43 additions & 24 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,63 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
p = value(prob.p)

u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
end
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)

z_arr = -inv(f_x) * f_p

f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
pp = prob.p
sumfun = let f_x′ = -f_x
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))
elseif p isa Number
partials = sumfun((z_arr, pp))
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
end
partials = sum(sumfun, zip(f_p, pp))

return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}},
alg::AbstractNewtonAlgorithm,
args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractNewtonAlgorithm,
args...;

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function scalar_nlsolve_∂f_∂p(f, u, p)
ff = p isa Number ? ForwardDiff.derivative :
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
return ff(Base.Fix1(f, u), p)
end

function scalar_nlsolve_∂f_∂u(f, u, p)
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
return ff(Base.Fix2(f, p), u)
end

function scalar_nlsolve_dual_soln(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials)
end

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
end
Loading

0 comments on commit b2946b1

Please sign in to comment.