Skip to content

Commit

Permalink
Do some type conversions to support Float32s
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jul 11, 2023
1 parent a59e23c commit ea45267
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions lib/OptimizationOptimisers/src/sophia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
C,
}
local x, cur, state
uType = eltype(cache.u0)
lr = uType(cache.opt.lr)
betas = uType.(cache.opt.betas)
eps = uType(cache.opt.eps)
weight_decay = uType(cache.opt.weight_decay)
rho = uType(cache.opt.rho)

Check warning on line 61 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L55-L61

Added lines #L55 - L61 were not covered by tests

if cache.data != Optimization.DEFAULT_DATA
maxiters = length(cache.data)
Expand Down Expand Up @@ -91,19 +97,20 @@ function SciMLBase.__solve(cache::OptimizationCache{
elseif cb_call
break

Check warning on line 98 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L86-L98

Added lines #L86 - L98 were not covered by tests
end
mₜ = cache.opt.betas[1] .* mₜ + (1 - cache.opt.betas[1]) .* gₜ
mₜ = betas[1] .* mₜ + (1 - betas[1]) .* gₜ

Check warning on line 100 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L100

Added line #L100 was not covered by tests

if i % cache.opt.k == 1
hₜ₋₁ = copy(hₜ)
u = randn(length(θ))
u = randn(uType, length(θ))
f.hv(hₜ, θ, u, d...)
hₜ = cache.opt.betas[2] .* hₜ₋₁ + (1 - cache.opt.betas[2]) .* (u .* hₜ)
hₜ = betas[2] .* hₜ₋₁ + (1 - betas[2]) .* (u .* hₜ)

Check warning on line 106 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L102-L106

Added lines #L102 - L106 were not covered by tests
end
θ = θ .- cache.opt.lr * cache.opt.weight_decay .* θ
θ = θ .- lr * weight_decay .* θ
θ = θ .-

Check warning on line 109 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L108-L109

Added lines #L108 - L109 were not covered by tests
cache.opt.lr .* clip.(mₜ ./ max.(hₜ, Ref(cache.opt.eps)), Ref(cache.opt.rho))
lr .* clip.(mₜ ./ max.(hₜ, Ref(eps)), Ref(rho))
end

Check warning on line 111 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L111

Added line #L111 was not covered by tests

return SciMLBase.build_solution(cache, cache.opt,

Check warning on line 113 in lib/OptimizationOptimisers/src/sophia.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/sophia.jl#L113

Added line #L113 was not covered by tests
θ,
cache.f(θ, cache.p))
x)
end

0 comments on commit ea45267

Please sign in to comment.