Skip to content

Commit

Permalink
Support callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jul 11, 2023
1 parent 50388f5 commit a59e23c
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lib/OptimizationOptimisers/src/sophia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
hₜ = zero(θ)
for (i, d) in enumerate(data)
f.grad(gₜ, θ, d...)
x = cache.f(θ, cache.p, d...)
cb_call = cache.callback(θ, x...)
if !(typeof(cb_call) <: Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
elseif cb_call
break
end
mₜ = cache.opt.betas[1] .* mₜ + (1 - cache.opt.betas[1]) .* gₜ
if i % cache.opt.k == 1
hₜ₋₁ = copy(hₜ)
Expand Down

0 comments on commit a59e23c

Please sign in to comment.