Merge pull request #716 from sethaxen/optimjl_state_grad_hess
Store grad/hess in state for Optim.jl
Vaibhavdixit02 authored Mar 16, 2024
2 parents 1723965 + 0a417ea commit 7e115c9
Showing 2 changed files with 57 additions and 13 deletions.
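
For context, this change lets a user-supplied callback observe the gradient and Hessian at each iteration through the OptimizationState it receives. Below is a minimal, illustrative sketch of such a callback (not part of this commit); the objective, parameter values, and logging are placeholders, and state.grad / state.hess are only populated when the chosen optimizer actually computes them.

# Illustrative user callback (not from this PR); field names match the diff below.
using Optimization, OptimizationOptimJL

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(f, zeros(2), [1.0, 100.0])

function my_callback(state, loss)
    # state.grad / state.hess are `nothing` when the trace does not carry them
    if state.grad !== nothing
        @info "iter $(state.iter)" sq_grad_norm = sum(abs2, state.grad)
    end
    return false  # a `true` return asks the solver to halt early
end

sol = solve(prob, Optim.BFGS(); callback = my_callback)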
18 changes: 13 additions & 5 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
@@ -133,11 +133,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
     error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")

     function _cb(trace)
-        θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] :
-            decompose_trace(trace).metadata["x"]
+        metadata = decompose_trace(trace).metadata
+        θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
         opt_state = Optimization.OptimizationState(iter = trace.iteration,
             u = θ,
             objective = x[1],
+            grad = get(metadata, "g(x)", nothing),
+            hess = get(metadata, "h(x)", nothing),
             original = trace)
         cb_call = cache.callback(opt_state, x...)
         if !(cb_call isa Bool)
@@ -252,12 +254,15 @@ function SciMLBase.__solve(cache::OptimizationCache{
     cur, state = iterate(cache.data)

     function _cb(trace)
+        metadata = decompose_trace(trace).metadata
         θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
-            decompose_trace(trace).metadata["centroid"] :
-            decompose_trace(trace).metadata["x"]
+            metadata["centroid"] :
+            metadata["x"]
         opt_state = Optimization.OptimizationState(iter = trace.iteration,
             u = θ,
             objective = x[1],
+            grad = get(metadata, "g(x)", nothing),
+            hess = get(metadata, "h(x)", nothing),
             original = trace)
         cb_call = cache.callback(opt_state, x...)
         if !(cb_call isa Bool)
@@ -341,8 +346,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
     cur, state = iterate(cache.data)

     function _cb(trace)
+        metadata = decompose_trace(trace).metadata
         opt_state = Optimization.OptimizationState(iter = trace.iteration,
-            u = decompose_trace(trace).metadata["x"],
+            u = metadata["x"],
+            grad = get(metadata, "g(x)", nothing),
+            hess = get(metadata, "h(x)", nothing),
             objective = x[1],
             original = trace)
         cb_call = cache.callback(opt_state, x...)
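
The pattern in all three callbacks is the same: Optim.jl's extended trace stores the current gradient under the metadata key "g(x)" and the Hessian under "h(x)", and looking them up with `get` and a `nothing` default lets gradient-free methods (NelderMead, SAMIN, SimulatedAnnealing) report nothing instead of throwing a KeyError. A standalone illustration with a hypothetical metadata Dict, not taken from an actual trace:

# Hypothetical trace metadata: gradient present, Hessian absent.
metadata = Dict("x" => [0.1, 0.2], "g(x)" => [0.3, -0.4])

θ    = metadata["x"]                    # current iterate
grad = get(metadata, "g(x)", nothing)   # -> [0.3, -0.4]
hess = get(metadata, "h(x)", nothing)   # -> nothing, no error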
52 changes: 44 additions & 8 deletions lib/OptimizationOptimJL/test/runtests.jl
@@ -3,6 +3,32 @@ using OptimizationOptimJL,
     Random, ModelingToolkit
 using Test

+struct CallbackTester
+    dim::Int
+    has_grad::Bool
+    has_hess::Bool
+end
+function CallbackTester(dim::Int; has_grad = false, has_hess = false)
+    CallbackTester(dim, has_grad, has_hess)
+end
+
+function (cb::CallbackTester)(state, loss_val)
+    @test length(state.u) == cb.dim
+    if cb.has_grad
+        @test state.grad isa AbstractVector
+        @test length(state.grad) == cb.dim
+    else
+        @test state.grad === nothing
+    end
+    if cb.has_hess
+        @test state.hess isa AbstractMatrix
+        @test size(state.hess) == (cb.dim, cb.dim)
+    else
+        @test state.hess === nothing
+    end
+    return false
+end
+
 @testset "OptimizationOptimJL.jl" begin
     rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
     x0 = zeros(2)
@@ -13,34 +39,43 @@ using Test
     sol = solve(prob,
         Optim.NelderMead(;
             initial_simplex = Optim.AffineSimplexer(; a = 0.025,
-                b = 0.5)))
+                b = 0.5)); callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1

     f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())

     Random.seed!(1234)
     prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
-    sol = solve(prob, SAMIN())
+    sol = solve(prob, SAMIN(); callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.IPNewton())
+    sol = solve(
+        prob, Optim.IPNewton();
+        callback = CallbackTester(length(x0); has_grad = true, has_hess = true)
+    )
     @test 10 * sol.objective < l1

     prob = OptimizationProblem(f, x0, _p)
     Random.seed!(1234)
-    sol = solve(prob, SimulatedAnnealing())
+    sol = solve(prob, SimulatedAnnealing(); callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.BFGS())
+    sol = solve(prob, Optim.BFGS(); callback = CallbackTester(length(x0); has_grad = true))
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.Newton())
+    sol = solve(
+        prob, Optim.Newton();
+        callback = CallbackTester(length(x0); has_grad = true, has_hess = true)
+    )
     @test 10 * sol.objective < l1

     sol = solve(prob, Optim.KrylovTrustRegion())
     @test 10 * sol.objective < l1

-    sol = solve(prob, Optim.BFGS(), maxiters = 1)
+    sol = solve(
+        prob, Optim.BFGS();
+        maxiters = 1, callback = CallbackTester(length(x0); has_grad = true)
+    )
     @test sol.original.iterations == 1

     sol = solve(prob, Optim.BFGS(), maxiters = 1, local_maxiters = 2)
@@ -92,7 +127,8 @@ using Test
     optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())

     prob = OptimizationProblem(optprob, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
-    sol = solve(prob, Optim.Fminbox())
+    sol = solve(
+        prob, Optim.Fminbox(); callback = CallbackTester(length(x0); has_grad = true))
     @test 10 * sol.objective < l1

     Random.seed!(1234)
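
Outside the test suite the same helper can be used directly. A sketch, assuming the CallbackTester definition above is in scope and the usual packages are loaded; the objective and parameters are illustrative:

# Newton's method computes both derivatives, so grad and hess should both be populated.
using Optimization, OptimizationOptimJL, Test

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(f, zeros(2), [1.0, 100.0])

sol = solve(
    prob, Optim.Newton();
    callback = CallbackTester(length(prob.u0); has_grad = true, has_hess = true)
)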
