Flesh out sparsity and second-order things more, change the extension structure
Vaibhavdixit02 committed Jul 15, 2024
1 parent ad7ca08 commit 4fe1f54
Showing 10 changed files with 112 additions and 113 deletions.
11 changes: 6 additions & 5 deletions Project.toml
@@ -7,6 +7,7 @@ version = "1.3.3"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
@@ -21,18 +22,20 @@ SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[weakdeps]
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
OptimizationDIExt = ["DifferentiationInterface", "ForwardDiff", "ReverseDiff"]
OptimizationForwardDiffExt = "ForwardDiff"
OptimizationFiniteDiffExt = "FiniteDiff"
OptimizationReverseDiffExt = "ReverseDiff"
OptimizationEnzymeExt = "Enzyme"
OptimizationMTKExt = "ModelingToolkit"
OptimizationZygoteExt = ["Zygote", "DifferentiationInterface"]
OptimizationZygoteExt = "Zygote"

[compat]
ADTypes = "1.3"
@@ -44,11 +47,9 @@ ModelingToolkit = "9"
Reexport = "1.2"
Requires = "1"
SciMLBase = "2"
SparseDiffTools = "2.14"
SymbolicAnalysis = "0.1, 0.2"
SymbolicIndexingInterface = "0.3"
Symbolics = "5.12"
Tracker = "0.2.29"
Zygote = "0.6.67"
julia = "1.10"

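With DifferentiationInterface promoted from [weakdeps] to a hard dependency, the DI-based code can live in src and always be available, while the backend AD packages stay optional behind thin stub extensions. A minimal usage sketch under that structure (the objective function here is hypothetical, not from this commit):

using OptimizationBase, ADTypes, SciMLBase
import ForwardDiff
# Loading ForwardDiff only triggers the empty OptimizationForwardDiffExt stub below;
# the actual AD glue now ships unconditionally in src/OptimizationDIExt.jl.
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = SciMLBase.OptimizationFunction(rosenbrock, ADTypes.AutoForwardDiff())
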
5 changes: 5 additions & 0 deletions ext/OptimizationFiniteDiffExt.jl
@@ -0,0 +1,5 @@
module OptimizationFiniteDiffExt

using DifferentiationInterface, FiniteDiff

end
5 changes: 5 additions & 0 deletions ext/OptimizationForwardDiffExt.jl
@@ -0,0 +1,5 @@
module OptimizationForwardDiffExt

using DifferentiationInterface, ForwardDiff

end
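
These new backend extensions are deliberately empty: a plausible reading is that they exist only to co-load DifferentiationInterface alongside the backend package, so that DI's own support for that backend activates. A hedged way to confirm a stub loaded (Julia ≥ 1.9; hypothetical session):

using OptimizationBase
import ForwardDiff
ext = Base.get_extension(OptimizationBase, :OptimizationForwardDiffExt)
ext === nothing && @warn "OptimizationForwardDiffExt did not load"  # expect the stub module, not nothing
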
6 changes: 3 additions & 3 deletions ext/OptimizationMTKExt.jl
@@ -7,8 +7,8 @@ import OptimizationBase.ADTypes: AutoModelingToolkit, AutoSymbolics, AutoSparse
using ModelingToolkit

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p,
num_cons = 0) where {S, C}
f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics}, p,
num_cons = 0)
p = isnothing(p) ? SciMLBase.NullParameters() : p

sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p;
@@ -53,7 +53,7 @@ function OptimizationBase.instantiate_function(
end

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::AutoSparse{<:AutoSymbolics, S, C}, num_cons = 0) where {S, C}
adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0)
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, cache.u0,
5 changes: 5 additions & 0 deletions ext/OptimizationReverseDiffExt.jl
@@ -0,0 +1,5 @@
module OptimizationReverseDiffExt

using DifferentiationInterface, ReverseDiff

end
2 changes: 2 additions & 0 deletions src/OptimizationBase.jl
@@ -32,6 +32,8 @@ Base.length(::NullData) = 0
include("adtypes.jl")
include("cache.jl")
include("function.jl")
include("OptimizationDIExt.jl")
include("OptimizationDISparseExt.jl")

export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA

49 changes: 30 additions & 19 deletions ext/OptimizationDIExt.jl → src/OptimizationDIExt.jl
@@ -1,20 +1,18 @@
module OptimizationDIExt

import OptimizationBase, OptimizationBase.ArrayInterface
using OptimizationBase
import OptimizationBase.ArrayInterface
import OptimizationBase.SciMLBase: OptimizationFunction
import OptimizationBase.LinearAlgebra: I
import DifferentiationInterface
import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian,
gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian
using ADTypes
import ForwardDiff, ReverseDiff
using ADTypes, SciMLBase

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

if ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff())
elseif ADTypes.mode(adtype) isa ADTypes.ReverseMode
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
end
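
This branch picks the second-order composition: a forward-mode primary backend gets a reverse-mode inner backend (with a note to maybe use Zygote), and a reverse-mode backend gets ForwardDiff on the outside, i.e. forward-over-reverse. A standalone sketch of the same pattern, reusing the prepare/hessian calling convention from later in this file (f and x are placeholders):

using DifferentiationInterface, ADTypes
import ForwardDiff, ReverseDiff
f(θ) = sum(abs2, θ)
x = rand(3)
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), AutoReverseDiff())  # outer forward over inner reverse
extras = prepare_hessian(f, soadtype, x)
H = hessian(f, soadtype, x, extras)  # forward-over-reverse Hessian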

@@ -32,7 +30,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x)
function hess(res, θ, args...)
hessian!(_f, res, adtype, θ, extras_hess)
hessian!(_f, res, soadtype, θ, extras_hess)
end
else
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
@@ -79,7 +77,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,

function cons_h(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i])
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
end
end
else
Expand All @@ -106,7 +104,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca
x = cache.u0
p = cache.p
_f = (θ, args...) -> first(f.f(θ, p, args...))
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)

if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
end

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
@@ -169,7 +172,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca

function cons_h(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i])
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
end
end
else
@@ -195,7 +198,12 @@ end

function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)

if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
end

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
@@ -211,7 +219,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
if f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better
function hess(θ, args...)
hessian(_f, adtype, θ, extras_hess)
hessian(_f, soadtype, θ, extras_hess)
end
else
hess = (θ, args...) -> f.hess(θ, p, args...)
@@ -259,7 +267,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x

function cons_h(θ)
H = map(1:num_cons) do i
hessian(fncs[i], adtype, θ, extras_cons_hess[i])
hessian(fncs[i], soadtype, θ, extras_cons_hess[i])
end
return H
end
@@ -287,7 +295,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
x = cache.u0
p = cache.p
_f = (θ, args...) -> first(f.f(θ, p, args...))
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)

if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
end

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
@@ -351,7 +364,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c

function cons_h(θ)
H = map(1:num_cons) do i
hessian(fncs[i], adtype, θ, extras_cons_hess[i])
hessian(fncs[i], soadtype, θ, extras_cons_hess[i])
end
return H
end
@@ -374,5 +387,3 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
cons_hess_colorvec = conshess_colors,
lag_h, f.lag_hess_prototype)
end

end
73 changes: 50 additions & 23 deletions ext/OptimizationDISparseExt.jl → src/OptimizationDISparseExt.jl
@@ -1,6 +1,5 @@
module OptimizationDIExt

import OptimizationBase, OptimizationBase.ArrayInterface
using OptimizationBase
import OptimizationBase.ArrayInterface
import OptimizationBase.SciMLBase: OptimizationFunction
import OptimizationBase.LinearAlgebra: I
import DifferentiationInterface
@@ -9,21 +8,48 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
using ADTypes
using SparseConnectivityTracer, SparseMatrixColorings

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse, p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

function generate_sparse_adtype(adtype)
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm())
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa AbstractADTypes.NoColoringAlgorithm)
if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm())
end
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm)
if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm)
end
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm())
if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm())
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm())
end
else
if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm)
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm)
end
end
return adtype, soadtype
end
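
The helper normalizes any AutoSparse backend: missing pieces default to tracer-based sparsity detection and greedy coloring, and a matching sparse SecondOrder type is built alongside. A sketch of the default pairing it produces in the forward-mode case (illustrative only):

using ADTypes, SparseConnectivityTracer, SparseMatrixColorings
import ForwardDiff
adtype = AutoSparse(AutoForwardDiff();
    sparsity_detector = TracerLocalSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())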


function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

adtype, soadtype = generate_sparse_adtype(adtype)

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
function grad(res, θ)
gradient!(_f, res, adtype, θ, extras_grad)
gradient!(_f, res, adtype.dense_ad, θ, extras_grad)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
@@ -34,7 +60,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better
function hess(res, θ, args...)
hessian!(_f, res, adtype, θ, extras_hess)
hessian!(_f, res, soadtype, θ, extras_hess)
end
else
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
@@ -81,7 +107,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,

function cons_h(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i])
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
end
end
else
@@ -104,11 +130,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
lag_h, f.lag_hess_prototype)
end
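
A hedged end-to-end sketch of reaching this sparse in-place method (the objective, constraint, starting point, and parameters are hypothetical):

using OptimizationBase, ADTypes, SciMLBase
import ForwardDiff
rosen(x, p) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
cons(res, x, p) = (res .= [x[1]^2 + x[2]^2]; nothing)
sparse_ad = AutoSparse(AutoForwardDiff())  # defaults filled in by generate_sparse_adtype above
optf = SciMLBase.OptimizationFunction(rosen, sparse_ad; cons = cons)
inst = OptimizationBase.instantiate_function(optf, zeros(2), sparse_ad, SciMLBase.NullParameters(), 1)
# inst should now carry the grad/hess/cons_h closures assembled above, with sparse Hessian handling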

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0)
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
x = cache.u0
p = cache.p
_f = (θ, args...) -> first(f.f(θ, p, args...))
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)

adtype, soadtype = generate_sparse_adtype(adtype)

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
@@ -171,7 +198,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca

function cons_h(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i])
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
end
end
else
@@ -195,9 +222,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca
end


function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0)
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)

adtype, soadtype = generate_sparse_adtype(adtype)

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
@@ -213,7 +241,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
if f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better
function hess(θ, args...)
hessian(_f, adtype, θ, extras_hess)
hessian(_f, soadtype, θ, extras_hess)
end
else
hess = (θ, args...) -> f.hess(θ, p, args...)
@@ -261,7 +289,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x

function cons_h(θ)
H = map(1:num_cons) do i
hessian(fncs[i], adtype, θ, extras_cons_hess[i])
hessian(fncs[i], soadtype, θ, extras_cons_hess[i])
end
return H
end
@@ -285,11 +313,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
lag_h, f.lag_hess_prototype)
end

function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0)
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
x = cache.u0
p = cache.p
_f = (θ, args...) -> first(f.f(θ, p, args...))
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)

adtype, soadtype = generate_sparse_adtype(adtype)

if f.grad === nothing
extras_grad = prepare_gradient(_f, adtype, x)
@@ -353,7 +382,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c

function cons_h(θ)
H = map(1:num_cons) do i
hessian(fncs[i], adtype, θ, extras_cons_hess[i])
hessian(fncs[i], soadtype, θ, extras_cons_hess[i])
end
return H
end
@@ -376,5 +405,3 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
cons_hess_colorvec = conshess_colors,
lag_h, f.lag_hess_prototype)
end

end