Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DifferentiationInterface #148

Merged
merged 4 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.8.1"
version = "1.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -17,21 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveTrackerExt = "Tracker"
SimpleNonlinearSolveZygoteExt = "Zygote"

Expand All @@ -41,13 +39,14 @@ AllocCheck = "0.1.1"
Aqua = "0.8"
ArrayInterface = "7.9"
CUDA = "5.2"
ChainRulesCore = "1.22"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149"
DiffResults = "1.1"
DifferentiationInterface = "0.4"
ExplicitImports = "1.5.0"
FastClosures = "0.3.2"
FiniteDiff = "2.22"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
LinearSolve = "2.30"
Expand All @@ -59,13 +58,14 @@ PrecompileTools = "1.2"
Random = "1.10"
ReTestItems = "1.23"
Reexport = "1.2"
ReverseDiff = "1.15"
ReverseDiff = "1.15.3"
SciMLBase = "2.37.0"
SciMLSensitivity = "7.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Tracker = "0.2.32"
Tracker = "0.2.33"
Zygote = "0.6.69"
julia = "1.10"

Expand Down
20 changes: 0 additions & 20 deletions ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

This file was deleted.

7 changes: 0 additions & 7 deletions ext/SimpleNonlinearSolveStaticArraysExt.jl

This file was deleted.

15 changes: 10 additions & 5 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module SimpleNonlinearSolve
using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations

@recompile_invalidations begin
using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
NONLINEARSOLVE_DEFAULT_NORM
using DifferentiationInterface: DifferentiationInterface
using DiffResults: DiffResults
using FastClosures: @closure
using FiniteDiff: FiniteDiff
Expand All @@ -18,13 +20,16 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init,
remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace,
_unwrap_val
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end

const DI = DifferentiationInterface

@reexport using SciMLBase

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
Expand Down
94 changes: 36 additions & 58 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{
<:AbstractArray, iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval function SciMLBase.solve(
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand Down Expand Up @@ -47,8 +37,7 @@
tspan = value.(prob.tspan)
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
newprob = remake(prob; p, u0 = value(prob.u0))
end

sol = solve(newprob, alg, args...; kwargs...)
Expand All @@ -73,20 +62,16 @@
end

function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
p = value(prob.p)
u0 = value(prob.u0)
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)

newprob = remake(prob; p = value(prob.p), u0 = value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
resid = __similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
Expand All @@ -101,9 +86,9 @@
elseif SciMLBase.has_jac(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
J = __similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = similar(du, length(sol.resid))
resid = __similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
Expand All @@ -116,43 +101,40 @@
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
res = DiffResults.DiffResult(
resid, similar(du, length(sol.resid), length(u)))
_f = @closure (du, u) -> prob.f(du, u, p)
ForwardDiff.jacobian!(res, _f, resid, u)
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
DiffResults.jacobian(res), 2, false)
resid = __similar(du, length(sol.resid))
v, J = DI.value_and_jacobian(_f, resid, AutoForwardDiff(), u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

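You may want to use DifferentiationInterface's preparation mechanism (e.g. `DI.prepare_jacobian`) for this `DI.value_and_jacobian` call.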

mul!(reshape(du, 1, :), vec(v)', J, 2, false)
return nothing
end
else
# For small problems, nesting ForwardDiff is actually quite fast
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50)
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
# TODO: Remove once DI has the value_and_pullback_split defined
_F = @closure (u, p) -> begin
_f = Base.Fix2(prob.f, p)
return __zygote_compute_nlls_vjp(_f, u, p)

Check warning on line 116 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L114-L116

Added lines #L114 - L116 were not covered by tests
end
else
_F = @closure (u, p) -> begin
T = promote_type(eltype(u), eltype(p))
res = DiffResults.DiffResult(similar(u, T, size(sol.resid)),
similar(u, T, length(sol.resid), length(u)))
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
return reshape(
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
size(u))
_f = Base.Fix2(prob.f, p)
v, J = DI.value_and_jacobian(_f, AutoForwardDiff(), u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

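Would DifferentiationInterface's preparation (e.g. `DI.prepare_jacobian`) be worthwhile for this `DI.value_and_jacobian` call?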

return reshape(2 .* vec(v)' * J, size(u))
end
end
end
end

f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
f_p = __nlsolve_∂f_∂p(prob, _F, uu, newprob.p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, newprob.p)

z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))
elseif p isa Number
elseif pp isa Number
partials = sumfun((z_arr, pp))
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
Expand All @@ -164,7 +146,7 @@
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
du = similar(u, promote_type(eltype(u), eltype(p)))
du = __similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
return du
end
Expand All @@ -182,16 +164,12 @@

@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
du = similar(u)
__f = (du, u) -> f(du, u, p)
ForwardDiff.jacobian(__f, du, u)
__f = @closure (du, u) -> f(du, u, p)
return ForwardDiff.jacobian(__f, __similar(u), u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use DifferentiationInterface (DI) here instead of calling `ForwardDiff.jacobian` directly?

else
__f = Base.Fix2(f, p)
if u isa Number
return ForwardDiff.derivative(__f, u)
else
return ForwardDiff.jacobian(__f, u)
end
u isa Number && return ForwardDiff.derivative(__f, u)
return ForwardDiff.jacobian(__f, u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

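Could this use DifferentiationInterface (DI) instead of calling `ForwardDiff.derivative` / `ForwardDiff.jacobian` directly?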

end
end

Expand Down
15 changes: 9 additions & 6 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,8 @@
α_1 = one(T)
f_1 = fx_norm

history_f_k = if x isa SArray ||
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
ones(SVector{M, T}) * fx_norm
else
fill(fx_norm, M)
end
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
__history_vec(fx_norm, Val(M))

# Generate the cache
@bb x_cache = similar(x)
Expand Down Expand Up @@ -150,6 +146,8 @@
# Store function value
if history_f_k isa SVector
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
elseif history_f_k isa NTuple
@set! history_f_k[mod1(k, M)] = fx_norm_new
else
history_f_k[mod1(k, M)] = fx_norm_new
end
Expand All @@ -158,3 +156,8 @@

return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}

Check warning on line 160 in src/nlsolve/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/dfsane.jl#L160

Added line #L160 was not covered by tests
M ≥ 11 && return :(fill(fx_norm, M)) # Julia can't specialize here
return :(ntuple(Returns(fx_norm), $(M)))
end
Loading
Loading