From ea45267d6f14adb9f55d8318a6acbfc9e2a02d5d Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 11 Jul 2023 21:30:15 +0530 Subject: [PATCH] Do some type conversions to support Float32s --- lib/OptimizationOptimisers/src/sophia.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/lib/OptimizationOptimisers/src/sophia.jl b/lib/OptimizationOptimisers/src/sophia.jl index 05bc4c0aa..47adbf351 100644 --- a/lib/OptimizationOptimisers/src/sophia.jl +++ b/lib/OptimizationOptimisers/src/sophia.jl @@ -53,6 +53,12 @@ function SciMLBase.__solve(cache::OptimizationCache{ C, } local x, cur, state + uType = eltype(cache.u0) + lr = uType(cache.opt.lr) + betas = uType.(cache.opt.betas) + eps = uType(cache.opt.eps) + weight_decay = uType(cache.opt.weight_decay) + rho = uType(cache.opt.rho) if cache.data != Optimization.DEFAULT_DATA maxiters = length(cache.data) @@ -91,19 +97,20 @@ function SciMLBase.__solve(cache::OptimizationCache{ elseif cb_call break end - mₜ = cache.opt.betas[1] .* mₜ + (1 - cache.opt.betas[1]) .* gₜ + mₜ = betas[1] .* mₜ + (1 - betas[1]) .* gₜ + if i % cache.opt.k == 1 hₜ₋₁ = copy(hₜ) - u = randn(length(θ)) + u = randn(uType, length(θ)) f.hv(hₜ, θ, u, d...) - hₜ = cache.opt.betas[2] .* hₜ₋₁ + (1 - cache.opt.betas[2]) .* (u .* hₜ) + hₜ = betas[2] .* hₜ₋₁ + (1 - betas[2]) .* (u .* hₜ) end - θ = θ .- cache.opt.lr * cache.opt.weight_decay .* θ + θ = θ .- lr * weight_decay .* θ θ = θ .- - cache.opt.lr .* clip.(mₜ ./ max.(hₜ, Ref(cache.opt.eps)), Ref(cache.opt.rho)) + lr .* clip.(mₜ ./ max.(hₜ, Ref(eps)), Ref(rho)) end return SciMLBase.build_solution(cache, cache.opt, θ, - cache.f(θ, cache.p)) + x) end