fix sparse adtype passed to hvp
Vaibhavdixit02 committed Jul 19, 2024
1 parent 35bd067 commit c1a5e1f
Showing 4 changed files with 98 additions and 21 deletions.
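
The substantive fix is in src/OptimizationDISparseExt.jl: `prepare_hvp` and `hvp!`/`hvp` were being handed the full `AutoSparse` second-order adtype, but a Hessian-vector product is a dense vector, so those calls now unwrap the underlying dense `SecondOrder` backend via `soadtype.dense_ad`. A minimal sketch of the corrected call pattern follows, mirroring the calls in this diff; the objective, evaluation point, and concrete backend pairing are illustrative assumptions, not code from the repository:

using ADTypes, SparseConnectivityTracer, SparseMatrixColorings
using DifferentiationInterface
using ForwardDiff, ReverseDiff  # assumed backends for the SecondOrder pairing below

f(x) = sum(abs2, x)  # toy objective with Hessian 2I
x = rand(3)
v = rand(3)

# Sparse second-order adtype in the forward-over-reverse shape that
# generate_sparse_adtype builds for forward-mode dense backends:
soadtype = AutoSparse(
    DifferentiationInterface.SecondOrder(AutoForwardDiff(), AutoReverseDiff()),
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())

# Before this commit the AutoSparse wrapper itself reached prepare_hvp/hvp!;
# after it, only the dense SecondOrder backend does:
extras_hvp = prepare_hvp(f, soadtype.dense_ad, x, v)
H = zero(x)
hvp!(f, H, soadtype.dense_ad, x, v, extras_hvp)  # H ≈ 2v for this objective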
src/OptimizationBase.jl (2 changes: 1 addition & 1 deletion)
@@ -31,9 +31,9 @@ Base.length(::NullData) = 0

include("adtypes.jl")
include("cache.jl")
include("function.jl")
include("OptimizationDIExt.jl")
include("OptimizationDISparseExt.jl")
include("function.jl")

export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA

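Two mechanical changes accompany the fix: `include("function.jl")` moves below the two DI includes above, and the DI files stop qualifying their method definitions as `OptimizationBase.instantiate_function` (the renames in the next two files). Because all of these files are `include`d into the OptimizationBase module itself, the qualified and unqualified spellings attach methods to the same generic function; a toy sketch of that equivalence, with invented names:

module M
g(x) = "fallback"
g(x::Vector) = "specialized"  # unqualified definition inside the module
end
M.g(x::Float64) = "float"      # qualified extension from outside

M.g(:a)      # "fallback"
M.g([1, 2])  # "specialized"
M.g(1.0)     # "float"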
src/OptimizationDIExt.jl (8 changes: 4 additions & 4 deletions)
@@ -9,7 +9,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
hvp, jacobian
using ADTypes, SciMLBase

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType,
p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -103,7 +103,7 @@ function OptimizationBase.instantiate_function(
lag_h, f.lag_hess_prototype)
end

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AbstractADType, num_cons = 0)
x = cache.u0
@@ -199,7 +199,7 @@ function OptimizationBase.instantiate_function(
lag_h, f.lag_hess_prototype)
end

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType,
p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -295,7 +295,7 @@ function OptimizationBase.instantiate_function(
lag_h, f.lag_hess_prototype)
end

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AbstractADType, num_cons = 0)
x = cache.u0
src/OptimizationDISparseExt.jl (52 changes: 36 additions & 16 deletions)
@@ -15,7 +15,12 @@ function generate_sparse_adtype(adtype)
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm())
-if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+soadtype = AutoSparse(
+DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+sparsity_detector = TracerSparsityDetector(),
+coloring_algorithm = GreedyColoringAlgorithm())
+elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -32,7 +37,12 @@
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
coloring_algorithm = adtype.coloring_algorithm)
-if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+soadtype = AutoSparse(
+DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+sparsity_detector = TracerSparsityDetector(),
+coloring_algorithm = adtype.coloring_algorithm)
+elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -49,7 +59,12 @@
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = GreedyColoringAlgorithm())
-if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+soadtype = AutoSparse(
+DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+sparsity_detector = adtype.sparsity_detector,
+coloring_algorithm = GreedyColoringAlgorithm())
+elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -63,7 +78,12 @@
coloring_algorithm = GreedyColoringAlgorithm())
end
else
-if !(adtype.dense_ad isa SciMLBase.NoAD) &&
+if adtype.dense_ad isa ADTypes.AutoFiniteDiff
+soadtype = AutoSparse(
+DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
+sparsity_detector = adtype.sparsity_detector,
+coloring_algorithm = adtype.coloring_algorithm)
+elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
@@ -80,7 +100,7 @@
return adtype, soadtype
end
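
Each of the four branches above gains the same special case: when the dense backend is `AutoFiniteDiff`, the second-order adtype pairs finite differences with itself rather than taking the generic forward-mode route, which would wrap it as `SecondOrder(dense_ad, AutoReverseDiff())` (ADTypes classifies `AutoFiniteDiff` as forward mode). A sketch of what `generate_sparse_adtype` now produces for a sparse finite-difference input; the detector and coloring choices mirror the defaults of the first branch:

using ADTypes, DifferentiationInterface
using SparseConnectivityTracer, SparseMatrixColorings

adtype = AutoSparse(AutoFiniteDiff();
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())

# FiniteDiff-over-FiniteDiff second order, as in the new branch:
soadtype = AutoSparse(
    DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())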

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType},
p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -108,9 +128,9 @@ function OptimizationBase.instantiate_function(
end

if f.hv === nothing
-extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
hv = function (H, θ, v, args...)
-hvp!(_f, H, soadtype, θ, v, extras_hvp)
+hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
end
else
hv = f.hv
@@ -168,7 +188,7 @@ function OptimizationBase.instantiate_function(
lag_h, f.lag_hess_prototype)
end
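
The `soadtype.dense_ad` unwrapping above, repeated in the three `instantiate_function` methods below, is the fix named in the commit title: sparsity detection and coloring pay off when assembling a full sparse Hessian, but the product `H*v` is a dense vector, so the HVP path presumably gains nothing from the `AutoSparse` wrapper and now receives the underlying dense `SecondOrder` backend directly.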

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
x = cache.u0
@@ -198,9 +218,9 @@ function OptimizationBase.instantiate_function(
end

if f.hv === nothing
-extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
hv = function (H, θ, v, args...)
-hvp!(_f, H, soadtype, θ, v, extras_hvp)
+hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
end
else
hv = f.hv
@@ -258,7 +278,7 @@ function OptimizationBase.instantiate_function(
lag_h, f.lag_hess_prototype)
end

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType},
p = SciMLBase.NullParameters(), num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))
@@ -286,9 +306,9 @@ function OptimizationBase.instantiate_function(
end

if f.hv === nothing
-extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
hv = function (θ, v, args...)
-hvp(_f, soadtype, θ, v, extras_hvp)
+hvp(_f, soadtype.dense_ad, θ, v, extras_hvp)
end
else
hv = f.hv
@@ -348,7 +368,7 @@ function OptimizationBase.instantiate_function(
lag_h, f.lag_hess_prototype)
end

-function OptimizationBase.instantiate_function(
+function instantiate_function(
f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0)
x = cache.u0
@@ -378,9 +398,9 @@ function OptimizationBase.instantiate_function(
end

if f.hv === nothing
-extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x)))
+extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x)))
hv = function (θ, v, args...)
-hvp(_f, soadtype, θ, v, extras_hvp)
+hvp(_f, soadtype.dense_ad, θ, v, extras_hvp)
end
else
hv = f.hv
src/function.jl (57 changes: 57 additions & 0 deletions)
@@ -43,6 +43,63 @@ function that is not defined, an error is thrown.
For more information on the use of automatic differentiation, see the
documentation of the `AbstractADType` types.
"""
+function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
+p, num_cons = 0)
+grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
+hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
+hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
+cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
+cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
+cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
+hess_prototype = f.hess_prototype === nothing ? nothing :
+convert.(eltype(x), f.hess_prototype)
+cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
+convert.(eltype(x), f.cons_jac_prototype)
+cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
+[convert.(eltype(x), f.cons_hess_prototype[i])
+for i in 1:num_cons]
+expr = symbolify(f.expr)
+cons_expr = symbolify.(f.cons_expr)
+
+return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
+hv = hv,
+cons = cons, cons_j = cons_j, cons_h = cons_h,
+hess_prototype = hess_prototype,
+cons_jac_prototype = cons_jac_prototype,
+cons_hess_prototype = cons_hess_prototype,
+expr = expr, cons_expr = cons_expr,
+sys = f.sys,
+observed = f.observed)
+end
+
+function instantiate_function(f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD,
+num_cons = 0)
+grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...)
+hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...)
+hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
+cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
+cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
+cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
+hess_prototype = f.hess_prototype === nothing ? nothing :
+convert.(eltype(cache.u0), f.hess_prototype)
+cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
+convert.(eltype(cache.u0), f.cons_jac_prototype)
+cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
+[convert.(eltype(cache.u0), f.cons_hess_prototype[i])
+for i in 1:num_cons]
+expr = symbolify(f.expr)
+cons_expr = symbolify.(f.cons_expr)
+
+return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
+hv = hv,
+cons = cons, cons_j = cons_j, cons_h = cons_h,
+hess_prototype = hess_prototype,
+cons_jac_prototype = cons_jac_prototype,
+cons_hess_prototype = cons_hess_prototype,
+expr = expr, cons_expr = cons_expr,
+sys = f.sys,
+observed = f.observed)
+end

function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType,
p, num_cons = 0)
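Both new methods are pass-throughs for user-supplied derivatives: they close over the parameters (`p` or `cache.p`), convert any prototypes to the element type of the initial point, and rewrap everything in a fresh `OptimizationFunction`. A hypothetical usage sketch of the first method follows; the objective, gradient, and parameter values are invented for illustration:

using SciMLBase, OptimizationBase

rosen(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
function rosen_grad!(G, x, p)
    G[1] = -2 * (p[1] - x[1]) - 4 * p[2] * x[1] * (x[2] - x[1]^2)
    G[2] = 2 * p[2] * (x[2] - x[1]^2)
    return G
end

optf = OptimizationFunction{true}(rosen; grad = rosen_grad!)
inst = OptimizationBase.instantiate_function(
    optf, zeros(2), SciMLBase.NoAD(), [1.0, 100.0])

G = zeros(2)
inst.grad(G, [0.5, 0.5])  # calls rosen_grad! with p = [1.0, 100.0] already bound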
