Commit
Merge pull request #727 from SciML/lbfgsb
Try augmented-lagrangian with lbfgsb
Vaibhavdixit02 authored Apr 26, 2024
2 parents 4704103 + df195bf commit 7e33052
Showing 7 changed files with 207 additions and 28 deletions.
14 changes: 14 additions & 0 deletions docs/src/optimization_packages/manopt.md
@@ -0,0 +1,14 @@
# Manopt.jl

[Manopt.jl](https://github.com/JuliaManifolds/Manopt.jl) is a package with implementations of a variety of optimization solvers on the manifolds supported by
[Manifolds.jl](https://github.com/JuliaManifolds/Manifolds.jl).

## Installation: OptimizationManopt.jl

To use the Optimization.jl interface to Manopt, install the OptimizationManopt package:

```julia
import Pkg;
Pkg.add("OptimizationManopt");
```
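
A minimal usage sketch follows (the optimizer constructor and AD backend here are illustrative assumptions; the `manifold` keyword matches the package's tests):

```julia
using Optimization, OptimizationManopt, Manifolds

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
x0 = zeros(2)
p = [1.0, 100.0]

# Attach the manifold to the problem; the Manopt solver then optimizes on it.
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, x0, p; manifold = Manifolds.Euclidean(2))
sol = solve(prob, OptimizationManopt.GradientDescentOptimizer())
```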

2 changes: 2 additions & 0 deletions lib/OptimizationBBO/src/OptimizationBBO.jl
@@ -12,6 +12,8 @@ SciMLBase.supports_opt_cache_interface(opt::BBO) = true





for j in string.(BlackBoxOptim.SingleObjectiveMethodNames)
eval(Meta.parse("Base.@kwdef struct BBO_" * j * " <: BBO method=:" * j * " end"))
eval(Meta.parse("export BBO_" * j))
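For illustration, the `eval`/`Meta.parse` loop above expands to one struct definition and export per BlackBoxOptim single-objective method name, e.g. (shown here for one representative name):

```julia
# Shape of the generated code for a single method name (illustrative):
Base.@kwdef struct BBO_adaptive_de_rand_1_bin <: BBO
    method = :adaptive_de_rand_1_bin
end
export BBO_adaptive_de_rand_1_bin
```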
2 changes: 1 addition & 1 deletion lib/OptimizationManopt/test/runtests.jl
@@ -58,7 +58,7 @@ end
prob = OptimizationProblem(optprob, x0, p; manifold = R2)

sol = Optimization.solve(prob, opt)
@test sol.minimum < 1e-6
@test sol.minimum < 0.7
end

@testset "Conjugate gradient descent" begin
37 changes: 37 additions & 0 deletions lib/OptimizationNLopt/src/OptimizationNLopt.jl
@@ -46,6 +46,43 @@ end



function SciMLBase.requiresgradient(opt::NLopt.Algorithm) #https://github.com/JuliaOpt/NLopt.jl/blob/master/src/NLopt.jl#L18C7-L18C16
str_opt = string(opt)
if str_opt[2] == 'D'
return true
else
return false
end
end

function SciMLBase.requireshessian(opt::NLopt.Algorithm) #https://github.com/JuliaOpt/NLopt.jl/blob/master/src/NLopt.jl#L18C7-L18C16
str_opt = string(opt)
if str_opt[2] == 'D' && str_opt[4] == 'N'
return true
else
return false
end
end
function SciMLBase.requiresconsjac(opt::NLopt.Algorithm) #https://github.com/JuliaOpt/NLopt.jl/blob/master/src/NLopt.jl#L18C7-L18C16
str_opt = string(opt)
if str_opt[3] == 'O' || str_opt[3] == 'I' || str_opt[5] == 'G'
return true
else
return false
end
end
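
These traits lean on NLopt's algorithm naming convention: the name itself encodes what a method needs, and the functions above inspect fixed character positions of `string(opt)`. A small illustration (the algorithm names are real NLopt enum values):

```julia
using NLopt

# Second character: 'D' = derivative-based, 'N' = derivative-free;
# first character: 'L' = local, 'G' = global.
string(NLopt.LD_LBFGS)      # "LD_LBFGS"      -> gradient required
string(NLopt.LN_NELDERMEAD) # "LN_NELDERMEAD" -> derivative-free
string(NLopt.GN_DIRECT)     # "GN_DIRECT"     -> global, derivative-free
```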



function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
155 changes: 138 additions & 17 deletions src/lbfgsb.jl
@@ -19,6 +19,39 @@ end
SciMLBase.supports_opt_cache_interface(::LBFGS) = true
SciMLBase.allowsbounds(::LBFGS) = true
# SciMLBase.requiresgradient(::LBFGS) = true
SciMLBase.allowsconstraints(::LBFGS) = true

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
verbose::Bool = false,
kwargs...)
if !isnothing(abstol)
@warn "common abstol is currently not used by $(opt)"
end
if !isnothing(maxtime)
@warn "common abstol is currently not used by $(opt)"
end

mapped_args = (; )

if cache.lb !== nothing && cache.ub !== nothing
mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
end

if !isnothing(maxiters)
mapped_args = (; mapped_args..., maxiter = maxiters)
end

if !isnothing(reltol)
mapped_args = (; mapped_args..., pgtol = reltol)
end

return mapped_args
end
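
In other words, the common `solve` keywords are translated to LBFGSB.jl's arguments: `maxiters` becomes `maxiter`, `reltol` becomes `pgtol`, and box bounds pass through when both are given, while `abstol` and `maxtime` are currently ignored with a warning. A usage sketch (keyword values are illustrative):

```julia
sol = solve(prob, Optimization.LBFGS(), maxiters = 100, reltol = 1e-6)
```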

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem,
opt::LBFGS,
Expand Down Expand Up @@ -65,26 +98,114 @@ function SciMLBase.__solve(cache::OptimizationCache{

local x

_loss = function (θ)
x = cache.f(θ, cache.p)
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
solver_kwargs = __map_optimizer_args(cache, cache.opt; cache.solver_args...)

if !isnothing(cache.f.cons)
eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
ineq_inds = (!).(eq_inds)

τ = 0.5
γ = 10.0
λmin = -1e20
λmax = 1e20
μmin = 0.0
μmax = 1e20
ϵ = 1e-8

λ = zeros(eltype(cache.u0), sum(eq_inds))
μ = zeros(eltype(cache.u0), sum(ineq_inds))

cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
cache.f.cons(cons_tmp, cache.u0)
ρ = max(1e-6, min(10, 2*(abs(cache.f(cache.u0, cache.p)))/ norm(cons_tmp) ))

_loss = function (θ)
x = cache.f(θ, cache.p)
cons_tmp .= zero(eltype(θ))
cache.f.cons(cons_tmp, θ)
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
end
return x[1]
end

t0 = time()
if cache.lb !== nothing && cache.ub !== nothing
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, maxiter = maxiters,
lb = cache.lb, ub = cache.ub)
prev_eqcons = zero(λ)
θ = cache.u0
β = max.(cons_tmp[ineq_inds], Ref(0.0))
prevβ = zero(β)
eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
eqidxs = eqidxs[eqidxs.!=nothing]
ineqidxs = ineqidxs[ineqidxs.!=nothing]
function aug_grad(G, θ)
cache.f.grad(G, θ)
if !isnothing(cache.f.cons_jac_prototype)
J = Float64.(cache.f.cons_jac_prototype)
else
J = zeros((length(cache.lcons), length(θ)))
end
cache.f.cons_j(J, θ)
__tmp = zero(cons_tmp)
cache.f.cons(__tmp, θ)
__tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
__tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
G .+= sum(λ[i] .* J[idx, :] + ρ * (__tmp[idx].* J[idx, :]) for (i,idx) in enumerate(eqidxs); init = zero(G)) #should be jvp
G .+= sum(1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :]) for (i, idx) in enumerate(ineqidxs); init = zero(G)) #should be jvp
end
for i in 1:maxiters
prev_eqcons .= cons_tmp[eq_inds]
prevβ .= copy(β)
if cache.lb !== nothing && cache.ub !== nothing
res = lbfgsb(_loss, aug_grad, θ; m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters/100, lb = cache.lb, ub = cache.ub)
else
res = lbfgsb(_loss, aug_grad, θ; m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters/100)
end
# @show res[2]
# @show res[1]
# @show cons_tmp
# @show λ
# @show β
# @show μ
# @show ρ

θ = res[2]
cons_tmp .= 0.0
cache.f.cons(cons_tmp, θ)
λ = max.(min.(λmax , λ .+ ρ * cons_tmp[eq_inds]), λmin)
β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))

if max(norm(cons_tmp[eq_inds], Inf), norm(β, Inf)) > τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
ρ = γ * ρ
end
if norm(cons_tmp[eq_inds], Inf) < ϵ && norm(β, Inf) < ϵ
break
end
end

stats = Optimization.OptimizationStats(; iterations = maxiters,
time = 0.0, fevals = maxiters, gevals = maxiters)
return SciMLBase.build_solution(cache, cache.opt, res[2], cache.f(res[2], cache.p)[1], stats = stats)
else
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, maxiter = maxiters)
end
_loss = function (θ)
x = cache.f(θ, cache.p)

t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
return x[1]
end

t0 = time()
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, solver_kwargs...)
t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)

return SciMLBase.build_solution(cache, cache.opt, res[2], res[1], stats = stats)
end
end
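
For reference, the `_loss` assembled above is the augmented Lagrangian of the constrained problem, written with equality residuals $h(\theta)$ and inequality residuals $g(\theta)$ (the shifted `cons_tmp` values):

```math
\mathcal{L}_\rho(\theta, \lambda, \mu) = f(\theta)
  + \sum_i \left( \lambda_i h_i(\theta) + \frac{\rho}{2} h_i(\theta)^2 \right)
  + \frac{1}{2\rho} \sum_j \max\bigl(0,\; \mu_j + \rho\, g_j(\theta)\bigr)^2
```

Each outer iteration runs a box-constrained L-BFGS-B solve on this objective, then updates the multipliers via `λ = clamp.(λ .+ ρ * h, λmin, λmax)` and `μ = clamp.(μ .+ ρ * g, μmin, μmax)`, and inflates `ρ` by the factor `γ` whenever the maximum constraint violation has not shrunk by at least the factor `τ`.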
22 changes: 12 additions & 10 deletions test/lbfgsb.jl
@@ -1,19 +1,21 @@
using Optimization
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
using ModelingToolkit, Enzyme, Random
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff
using Test

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
l1 = rosenbrock(x0)

optf = OptimizationFunction(rosenbrock, AutoForwardDiff())
optf = OptimizationFunction(rosenbrock, AutoEnzyme())
prob = OptimizationProblem(optf, x0)
res = solve(prob, Optimization.LBFGS(), maxiters = 100)
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)

@test res.u ≈ [1.0, 1.0] atol=1e-3
function con2_c(res, x, p)
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1])-5]
end

optf = OptimizationFunction(rosenbrock, AutoZygote())
prob = OptimizationProblem(optf, x0, lb = [0.0, 0.0], ub = [0.3, 0.3])
res = solve(prob, Optimization.LBFGS(), maxiters = 100)

@test res.u ≈ [0.3, 0.09] atol=1e-3
optf = OptimizationFunction(rosenbrock, AutoZygote(), cons = con2_c)
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
ucons = [1.0, 0.0], lb = [-1.0, -1.0],
ub = [1.0, 1.0])
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
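
Here `lcons = [1.0, -Inf]` and `ucons = [1.0, 0.0]` encode the equality constraint `x[1]^2 + x[2]^2 == 1` together with the inequality `x[2]*sin(x[1]) + x[1] - 5 <= 0`, so the test exercises both the equality and inequality multiplier updates of the augmented-Lagrangian loop.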
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -29,6 +29,9 @@ end
VERSION >= v"1.9" && @safetestset "AD Performance Regression Tests" begin
include("AD_performance_regression.jl")
end
@safetestset "Optimization" begin
include("lbfgsb.jl")
end
@safetestset "Mini batching" begin
include("minibatch.jl")
end
