From 8ba6cb6d0b2f29d6c01c967ea164c4a70629bde5 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Mon, 12 Aug 2024 14:54:18 -0400
Subject: [PATCH] Parameter-Free Optimization Algorithms (#81)

* add parameter averaging strategies and parameter free optimizers
* rename optimization rules file, add docstrings to averagers
* add tests for averaging and custom rules
* change `optimize` interface, update tests and docs, add optim docs
* update comments and use fancier LaTeX math
---
 README.md                                     |  4 +-
 docs/make.jl                                  | 31 ++++---
 docs/src/elbo/repgradelbo.md                  | 92 ++++++++++++-------
 docs/src/examples.md                          | 73 +++++++++------
 docs/src/optimization.md                      | 26 ++++++
 src/AdvancedVI.jl                             | 48 +++++++++-
 src/optimization/averaging.jl                 | 53 +++++++++++
 src/optimization/rules.jl                     | 89 ++++++++++++++++++
 src/optimize.jl                               | 19 +++-
 src/utils.jl                                  |  8 ++
 test/Project.toml                             |  1 +
 test/inference/repgradelbo_distributionsad.jl | 18 ++--
 test/inference/repgradelbo_locationscale.jl   | 18 ++--
 .../repgradelbo_locationscale_bijectors.jl    | 18 ++--
 test/interface/averaging.jl                   | 38 ++++++
 test/interface/optimize.jl                    | 24 +++--
 test/interface/rules.jl                       | 27 ++++++
 test/runtests.jl                              | 12 ++-
 18 files changed, 475 insertions(+), 124 deletions(-)
 create mode 100644 docs/src/optimization.md
 create mode 100644 src/optimization/averaging.jl
 create mode 100644 src/optimization/rules.jl
 create mode 100644 test/interface/averaging.jl
 create mode 100644 test/interface/rules.jl

diff --git a/README.md b/README.md
index 1ca0d848..f3bb745f 100644
--- a/README.md
+++ b/README.md
@@ -98,7 +98,7 @@ q_transformed = Bijectors.TransformedDistribution(q, binv)
 
 # Run inference
 max_iter = 10^3
-q, stats, _ = AdvancedVI.optimize(
+q_avg, _, stats, _ = AdvancedVI.optimize(
     model,
     elbo,
     q_transformed,
@@ -108,7 +108,7 @@ q, stats, _ = AdvancedVI.optimize(
 )
 
 # Evaluate final ELBO with 10^4 Monte Carlo samples
-estimate_objective(elbo, q, model; n_samples=10^4)
+estimate_objective(elbo, q_avg, model; n_samples=10^4)
 ```
 
 For more examples and details, please refer to the documentation.
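For context, a minimal sketch (not part of the diff) of how the new keywords introduced by this patch compose in the README example above; `DoWG` and `PolynomialAveraging` are the exports added later in this patch, and `model`, `elbo`, `q_transformed`, and `max_iter` are assumed from the snippet above:

```julia
# Hypothetical usage of the new `optimize` interface on the README example.
# `model`, `elbo`, `q_transformed`, and `max_iter` are assumed defined as above.
using ADTypes, ForwardDiff

q_avg, q_last, stats, _ = AdvancedVI.optimize(
    model,
    elbo,
    q_transformed,
    max_iter;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=DoWG(),                # parameter-free rule: no step size to tune
    averager=PolynomialAveraging(),  # q_avg is the averaged iterate
)
```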
diff --git a/docs/make.jl b/docs/make.jl
index 7ae3bc62..b71d9a4f 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -2,23 +2,24 @@
 using AdvancedVI
 using Documenter
 
-DocMeta.setdocmeta!(
-    AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true
-)
+DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true)
 
 makedocs(;
-    modules = [AdvancedVI],
-    sitename = "AdvancedVI.jl",
-    repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
-    format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
-    pages = ["AdvancedVI" => "index.md",
-             "General Usage" => "general.md",
-             "Examples" => "examples.md",
-             "ELBO Maximization" => [
-                 "Overview" => "elbo/overview.md",
-                 "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
-                 "Location-Scale Variational Family" => "locscale.md",
-             ]],
+    modules=[AdvancedVI],
+    sitename="AdvancedVI.jl",
+    repo="https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
+    format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"),
+    pages=[
+        "AdvancedVI" => "index.md",
+        "General Usage" => "general.md",
+        "Examples" => "examples.md",
+        "ELBO Maximization" => [
+            "Overview" => "elbo/overview.md",
+            "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
+            "Location-Scale Variational Family" => "locscale.md",
+        ],
+        "Optimization" => "optimization.md",
+    ],
 )
 
 deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md
index ee7854b2..10af6a52 100644
--- a/docs/src/elbo/repgradelbo.md
+++ b/docs/src/elbo/repgradelbo.md
@@ -1,56 +1,67 @@
-
 # [Reparameterization Gradient Estimator](@id repgradelbo)
+
 ## Overview
 
 The reparameterization gradient[^TL2014][^RMW2014][^KW2014] is an unbiased gradient estimator of the ELBO.
 Consider some variational family
+
 ```math
 \mathcal{Q} = \{q_{\lambda} \mid \lambda \in \Lambda \},
 ```
+
 where $$\lambda$$ denotes the *variational parameters* of $$q_{\lambda}$$.
 If its sampling process can be described by some differentiable reparameterization function $$\mathcal{T}_{\lambda}$$ and a *base distribution* $$\varphi$$ independent of $$\lambda$$ such that
+
 ```math
 z \sim q_{\lambda} \qquad\Leftrightarrow\qquad z \stackrel{d}{=} \mathcal{T}_{\lambda}\left(\epsilon\right);\quad \epsilon \sim \varphi
 ```
-we can effectively estimate the gradient of the ELBO by directly differentiating
+
+we can effectively estimate the gradient of the ELBO by directly differentiating
+
 ```math
 \widehat{\mathrm{ELBO}}\left(\lambda\right) = \frac{1}{M}\sum^M_{m=1} \log \pi\left(\mathcal{T}_{\lambda}\left(\epsilon_m\right)\right) + \mathbb{H}\left(q_{\lambda}\right),
 ```
+
 where $$\epsilon_m \sim \varphi$$ are Monte Carlo samples, with respect to $$\lambda$$.
 This estimator is called the reparameterization gradient estimator.
 
 In addition to the reparameterization gradient, `AdvancedVI` provides the following features:
-1. **Posteriors with constrained supports** are handled through [`Bijectors`](), which is known as the automatic differentiation VI (ADVI; [^KTRGB2017]) formulation. (See [this section](@ref bijectors).)
-2. **The gradient of the entropy** can be estimated through various strategies depending on the capabilities of the variational family. (See [this section](@ref entropygrad).)
-[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*.
+
+ 1. **Posteriors with constrained supports** are handled through [`Bijectors`](https://github.com/TuringLang/Bijectors.jl), which is known as the automatic differentiation VI (ADVI; [^KTRGB2017]) formulation. (See [this section](@ref bijectors).)
+ 2. **The gradient of the entropy** can be estimated through various strategies depending on the capabilities of the variational family. (See [this section](@ref entropygrad).)
+
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*.
 [^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*.
 [^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In *International Conference on Learning Representations*.
-
 ## The `RepGradELBO` Objective
 
 To use the reparameterization gradient, `AdvancedVI` provides the following variational objective:
+
 ```@docs
 RepGradELBO
 ```
 
 ## [Handling Constraints with `Bijectors`](@id bijectors)
+
 As mentioned in the docstring, the `RepGradELBO` objective assumes that the variational approximation $$q_{\lambda}$$ and the target distribution $$\pi$$ have the same support for all $$\lambda \in \Lambda$$.
 However, in general, it is most convenient to use variational families that have the whole Euclidean space $$\mathbb{R}^d$$ as their support.
 This is the case for the [location-scale distributions](@ref locscale) provided by `AdvancedVI`.
 For target distributions whose support is not the full $$\mathbb{R}^d$$, we can apply some transformation $$b$$ to $$q_{\lambda}$$ to match its support such that
+
 ```math
 z \sim q_{b,\lambda} \qquad\Leftrightarrow\qquad z \stackrel{d}{=} b^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda},
 ```
+
 where $$b$$ is often called a *bijector*, since it is typically chosen among bijective transformations.
 This idea is known as automatic differentiation VI[^KTRGB2017] and has subsequently been improved by Tensorflow Probability[^DLTBV2017].
 In Julia, [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl)[^FXTYG2020] provides a comprehensive collection of bijections.
 
 One caveat of ADVI is that, after applying the bijection, a Jacobian adjustment needs to be applied.
-That is, the objective is now
+That is, the objective is now
+
 ```math
 \mathrm{ADVI}\left(\lambda\right)
 \triangleq
@@ -63,28 +74,30 @@ That is, the objective is now
 
 This is automatically handled by `AdvancedVI` through `TransformedDistribution` provided by `Bijectors.jl`.
 See the following example:
+
 ```julia
 using Bijectors
-q    = MeanFieldGaussian(μ, L)
-b    = Bijectors.bijector(dist)
-binv = inverse(b)
+q = MeanFieldGaussian(μ, L)
+b = Bijectors.bijector(dist)
+binv = inverse(b)
 q_transformed = Bijectors.TransformedDistribution(q, binv)
 ```
+
 By passing `q_transformed` to `optimize`, the Jacobian adjustment for the bijector `b` is automatically applied.
 (See [Examples](@ref examples) for a fully working example.)
 
 [^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of Machine Learning Research*.
 [^DLTBV2017]: Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv.
 [^FXTYG2020]: Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020). Bijectors.jl: Flexible transformations for probability distributions. In *Symposium on Advances in Approximate Bayesian Inference*.
-
 ## [Entropy Estimators](@id entropygrad)
+
 For the gradient of the entropy term, we provide three choices with varying requirements.
 The user can select the entropy estimator by passing it as a keyword argument when constructing the `RepGradELBO` objective.
 
-| Estimator | `entropy(q)` | `logpdf(q)` | Type |
-| :--- | :---: | :---: | :--- |
-| `ClosedFormEntropy` | required | | Deterministic |
-| `MonteCarloEntropy` | | required | Monte Carlo |
+| Estimator                   | `entropy(q)` | `logpdf(q)` | Type                             |
+|:--------------------------- |:------------:|:-----------:|:-------------------------------- |
+| `ClosedFormEntropy`         | required     |             | Deterministic                    |
+| `MonteCarloEntropy`         |              | required    | Monte Carlo                      |
 | `StickingTheLandingEntropy` | | required | Monte Carlo with control variate |
 
 The requirements mean that either `Distributions.entropy` or `Distributions.logpdf` needs to be implemented for the choice of variational family.
@@ -93,10 +106,13 @@
 If `entropy` is not available, then `StickingTheLandingEntropy` is recommended.
 See the following section for more details.
 
 ### The `StickingTheLandingEntropy` Estimator
+
 The `StickingTheLandingEntropy`, or STL estimator, is a control variate approach [^RWD2017].
+
 ```@docs
 StickingTheLandingEntropy
 ```
+
 It occasionally results in lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``.
 The conditions under which the STL estimator results in lower variance are still an active subject of research.
@@ -165,33 +181,38 @@
 This setting is known as "perfect variational family specification."
 In this case, the `RepGradELBO` estimator with `StickingTheLandingEntropy` is the only estimator known to converge exponentially fast ("linear convergence") to the true solution.
 
 Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows:
+
 ```@example repgradelbo
 n_montecarlo = 16;
-b    = Bijectors.bijector(model);
-binv = inverse(b)
+b = Bijectors.bijector(model);
+binv = inverse(b)
 q0_trans = Bijectors.TransformedDistribution(q0, binv)
 
 cfe = AdvancedVI.RepGradELBO(n_montecarlo)
 nothing
 ```
+
 The `RepGradELBO` estimator with the STL entropy can instead be created as follows:
+
 ```@example repgradelbo
-repgradelbo = AdvancedVI.RepGradELBO(n_montecarlo; entropy = AdvancedVI.StickingTheLandingEntropy());
+repgradelbo = AdvancedVI.RepGradELBO(
+    n_montecarlo; entropy=AdvancedVI.StickingTheLandingEntropy()
+);
 nothing
 ```
 
 ```@setup repgradelbo
 max_iter = 3*10^3
 
-function callback(; stat, state, params, restructure, gradient)
+function callback(; params, restructure, kwargs...)
     q = restructure(params).dist
     dist2 = sum(abs2, q.location - vcat([μ_x], μ_y)) +
         sum(abs2, diag(q.scale) - vcat(σ_x, σ_y))
     (dist = sqrt(dist2),)
 end
 
-_, stats_cfe, _ = AdvancedVI.optimize(
+_, _, stats_cfe, _ = AdvancedVI.optimize(
     model,
     cfe,
     q0_trans,
@@ -202,7 +223,7 @@ _, stats_cfe, _ = AdvancedVI.optimize(
     callback = callback,
 );
 
-_, stats_stl, _ = AdvancedVI.optimize(
+_, _, stats_stl, _ = AdvancedVI.optimize(
     model,
     repgradelbo,
     q0_trans,
@@ -227,6 +248,7 @@ plot!(t, dist_stl, label="BBVI STL", xlabel="Iteration", ylabel="distance to opt
 savefig("advi_stl_dist.svg")
 nothing
 ```
+
 ![](advi_stl_elbo.svg)
 
 We can see that the noise of the STL estimator becomes smaller as VI converges.
@@ -243,15 +265,16 @@
 Furthermore, in a lot of cases, a low-accuracy solution may be sufficient.
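To summarize the estimator choices discussed in this section, here is a hedged sketch (not part of the diff) of constructing each variant from the table above; the zero-argument constructors for `ClosedFormEntropy` and `MonteCarloEntropy` are assumed to mirror `StickingTheLandingEntropy()`, and `n_montecarlo` is as defined earlier:

```julia
# Sketch: the three entropy-gradient variants from the table above.
# Zero-argument constructors for ClosedFormEntropy and MonteCarloEntropy are assumed.
cfe_obj = AdvancedVI.RepGradELBO(n_montecarlo; entropy=AdvancedVI.ClosedFormEntropy())
mc_obj = AdvancedVI.RepGradELBO(n_montecarlo; entropy=AdvancedVI.MonteCarloEntropy())
stl_obj = AdvancedVI.RepGradELBO(n_montecarlo; entropy=AdvancedVI.StickingTheLandingEntropy())
```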
 [^RWD2017]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
 [^KMG2024]: Kim, K., Ma, Y., & Gardner, J. (2024). Linear Convergence of Black-Box Variational Inference: Should We Stick the Landing?. In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR.
 
-## Advanced Usage
+## Advanced Usage
+
 There are two major ways to customize the behavior of `RepGradELBO`:
-* Customize the `Distributions` functions: `rand(q)`, `entropy(q)`, `logpdf(q)`.
-* Customize `AdvancedVI.reparam_with_entropy`.
+
+  - Customize the `Distributions` functions: `rand(q)`, `entropy(q)`, `logpdf(q)`.
+  - Customize `AdvancedVI.reparam_with_entropy`.
 
 It is generally recommended to customize `rand(q)`, `entropy(q)`, `logpdf(q)`, since they easily compose with other functionalities provided by `AdvancedVI`.
 
-The most advanced way is to customize `AdvancedVI.reparam_with_entropy`.
+The most advanced way is to customize `AdvancedVI.reparam_with_entropy`.
 In particular, `reparam_with_entropy` is the function that invokes `rand(q)`, `entropy(q)`, `logpdf(q)`.
 Thus, it is the most general way to override the behavior of `RepGradELBO`.
@@ -267,26 +290,27 @@ In this case, it suffices to override its `rand` specialization as follows:
 using QuasiMonteCarlo
 using StatsFuns
 
-qmcrng = SobolSample(R = OwenScramble(base = 2, pad = 32))
+qmcrng = SobolSample(; R=OwenScramble(; base=2, pad=32))
 
 function Distributions.rand(
-    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
-) where {L, D}
-    @unpack location, scale, dist = q
-    n_dims       = length(location)
-    scale_diag   = diag(scale)
+    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
+) where {L,D}
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    scale_diag = diag(scale)
     unif_samples = QuasiMonteCarlo.sample(num_samples, length(q), qmcrng)
-    std_samples  = norminvcdf.(unif_samples)
-    scale_diag.*std_samples .+ location
+    std_samples = norminvcdf.(unif_samples)
+    return scale_diag .* std_samples .+ location
 end
 nothing
 ```
+
 (Note that this is a quick-and-dirty example, and there are more sophisticated ways to implement this.)
 
 ```@setup repgradelbo
 repgradelbo = AdvancedVI.RepGradELBO(n_montecarlo);
 
-_, stats_qmc, _ = AdvancedVI.optimize(
+_, _, stats_qmc, _ = AdvancedVI.optimize(
     model,
     repgradelbo,
     q0_trans,
diff --git a/docs/src/examples.md b/docs/src/examples.md
index dbf1de45..15b8907a 100644
--- a/docs/src/examples.md
+++ b/docs/src/examples.md
@@ -1,15 +1,18 @@
-
 ## [Evidence Lower Bound Maximization](@id examples)
+
 In this tutorial, we will work with a `normal-log-normal` model.
+
 ```math
 \begin{aligned}
 x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
 y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
 \end{aligned}
 ```
+
 BBVI with `Bijectors.Exp` bijectors is able to infer this model exactly.
 Using the `LogDensityProblems` interface, the model can be defined as follows:
+
 ```@example elboexample
 using LogDensityProblems
 using SimpleUnPack
@@ -23,43 +26,47 @@ end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
     @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+    return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
 function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
+    return length(model.μ_y) + 1
 end
 
 function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
+    return LogDensityProblems.LogDensityOrder{0}()
 end
 ```
+
 Let's now instantiate the model.
+
 ```@example elboexample
 using LinearAlgebra
 
 n_dims = 10
-μ_x  = randn()
-σ_x  = exp.(randn())
-μ_y  = randn(n_dims)
-σ_y  = exp.(randn(n_dims))
-model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
+μ_x = randn()
+σ_x = exp.(randn())
+μ_y = randn(n_dims)
+σ_y = exp.(randn(n_dims))
+model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2));
 nothing
 ```
 
 Since `x` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
 Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
+
 ```@example elboexample
 using Bijectors
 
 function Bijectors.bijector(model::NormalLogNormal)
     @unpack μ_x, σ_x, μ_y, Σ_y = model
-    Bijectors.Stacked(
+    return Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
+        [1:1, 2:(1 + length(μ_y))],
+    )
 end
 
-b    = Bijectors.bijector(model);
+b = Bijectors.bijector(model);
 binv = inverse(b)
 nothing
 ```
@@ -68,62 +75,76 @@
 Let's now load `AdvancedVI`.
 Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`.
 Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
 Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.
+
 ```@example elboexample
 using Optimisers
 using ADTypes, ForwardDiff
 using AdvancedVI
 ```
+
 We now need to select 1. a variational objective, and 2. a variational family.
 Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
+
 ```@example elboexample
 n_montecarlo = 10;
-objective   = RepGradELBO(n_montecarlo)
+objective = RepGradELBO(n_montecarlo)
 ```
+
 For the variational family, we will use the classic mean-field Gaussian family.
+
 ```@example elboexample
-d = LogDensityProblems.dimension(model);
-μ = randn(d);
-L = Diagonal(ones(d));
+d = LogDensityProblems.dimension(model);
+μ = randn(d);
+L = Diagonal(ones(d));
 q0 = AdvancedVI.MeanFieldGaussian(μ, L)
 nothing
 ```
+
-And then, we now apply the bijector to the variational family.
+We then apply the bijector to the variational family.
+
 ```@example elboexample
 q0_trans = Bijectors.TransformedDistribution(q0, binv)
 nothing
 ```
 
 Passing `objective` and the initial variational approximation `q0_trans` to `optimize` performs inference.
+
 ```@example elboexample
 n_max_iter = 10^4
-q_trans, stats, _ = AdvancedVI.optimize(
+q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
     model,
     objective,
     q0_trans,
     n_max_iter;
-    show_progress = false,
-    adtype = AutoForwardDiff(),
-    optimizer = Optimisers.Adam(1e-3)
-);
+    show_progress=false,
+    adtype=AutoForwardDiff(),
+    optimizer=Optimisers.Adam(1e-3),
+);
 nothing
 ```
 
+`q_avg_trans` is the final output of the optimization procedure.
+If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` will be the output of the averaging strategy, while `q_trans` is the last iterate.
+
 The selected inference procedure stores per-iteration statistics into `stats`.
 For instance, the ELBO can be plotted as follows:
+
 ```@example elboexample
 using Plots
 
-t = [stat.iteration for stat ∈ stats]
-y = [stat.elbo for stat ∈ stats]
-plot(t, y, label="BBVI", xlabel="Iteration", ylabel="ELBO")
+t = [stat.iteration for stat in stats]
+y = [stat.elbo for stat in stats]
+plot(t, y; label="BBVI", xlabel="Iteration", ylabel="ELBO")
 savefig("bbvi_example_elbo.svg")
 nothing
 ```
+
 ![](bbvi_example_elbo.svg)
 
 Further information can be gathered by defining your own `callback`.
 
 The final ELBO can be estimated by calling `estimate_objective` with a different number of Monte Carlo samples as follows:
+
 ```@example elboexample
-estimate_objective(objective, q_trans, model; n_samples=10^4)
+estimate_objective(objective, q_avg_trans, model; n_samples=10^4)
 ```
diff --git a/docs/src/optimization.md b/docs/src/optimization.md
new file mode 100644
index 00000000..315f896e
--- /dev/null
+++ b/docs/src/optimization.md
@@ -0,0 +1,26 @@
+# [Optimization](@id optim)
+
+## Parameter-Free Optimization Rules
+
+We provide custom optimization rules that are not available out-of-the-box in [Optimisers.jl](https://github.com/FluxML/Optimisers.jl).
+The main theme of the provided optimizers is that they are parameter-free.
+This means that these optimization rules should require little to no tuning to obtain performance competitive with well-tuned alternatives.
+
+```@docs
+DoG
+DoWG
+COCOB
+```
+
+## Parameter Averaging Strategies
+
+In some cases, the best optimization performance is obtained by averaging the sequence of parameters generated by the optimization algorithm.
+For instance, the `DoG`[^IHC2023] and `DoWG`[^KMJ2024] papers report their best performance through averaging.
+The benefits of parameter averaging have been specifically confirmed for ELBO maximization[^DCAMHV2020].
+
+```@docs
+NoAveraging
+PolynomialAveraging
+```
+
+[^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index c1fa33d7..dfd682d5 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -181,7 +181,53 @@ export MvLocationScale, MeanFieldGaussian, FullRankGaussian
 
 include("families/location_scale.jl")
 
-# Optimization Routine
+# Optimization Rules
+
+include("optimization/rules.jl")
+
+export DoWG, DoG, COCOB
+
+# Output averaging strategy
+
+abstract type AbstractAverager end
+
+"""
+    init(avg, params)
+
+Initialize the state of the averaging strategy `avg` with the initial parameters `params`.
+
+# Arguments
+- `avg::AbstractAverager`: Averaging strategy.
+- `params`: Initial variational parameters.
+"""
+init(::AbstractAverager, ::Any) = nothing
+
+"""
+    apply(avg, avg_st, params)
+
+Apply averaging strategy `avg` on `params` given the state `avg_st`.
+
+# Arguments
+- `avg::AbstractAverager`: Averaging strategy.
+- `avg_st`: Previous state of the averaging strategy.
+- `params`: Current variational parameters.
+"""
+function apply(::AbstractAverager, ::Any, ::Any) end
+
+"""
+    value(avg, avg_st)
+
+Compute the output of the averaging strategy `avg` from the state `avg_st`.
+
+# Arguments
+- `avg::AbstractAverager`: Averaging strategy.
+- `avg_st`: Previous state of the averaging strategy.
+"""
+function value(::AbstractAverager, ::Any) end
+
+include("optimization/averaging.jl")
+
+export NoAveraging, PolynomialAveraging
 
 function optimize end
 
diff --git a/src/optimization/averaging.jl b/src/optimization/averaging.jl
new file mode 100644
index 00000000..19c375d8
--- /dev/null
+++ b/src/optimization/averaging.jl
@@ -0,0 +1,53 @@
+
+"""
+    NoAveraging()
+
+No averaging. This returns the last iterate of the optimization rule.
+"""
+struct NoAveraging <: AbstractAverager end
+
+init(::NoAveraging, x) = x
+
+apply(::NoAveraging, state, x) = x
+
+value(::NoAveraging, state) = state
+
+"""
+    PolynomialAveraging(eta)
+
+Polynomial averaging rule proposed by Shamir and Zhang[^SZ2013].
+At iteration `t`, the parameter average \$ \\bar{\\lambda}_t \$ according to the polynomial averaging rule is given as
+```math
+    \\bar{\\lambda}_t = (1 - w_t) \\bar{\\lambda}_{t-1} + w_t \\lambda_t \\, ,
+```
+where the averaging weight is
+```math
+    w_t = \\frac{\\eta + 1}{t + \\eta} \\, .
+```
+Higher `eta` (\$\\eta\$) down-weights earlier iterations.
+When \$\\eta=0\$, this is equivalent to uniformly averaging the iterates in an online fashion.
+The DoG paper[^IHC2023] suggests \$\\eta=8\$.
+
+# Parameters
+- `eta`: Regularization term. (default: `8`)
+
+[^SZ2013]: Shamir, O., & Zhang, T. (2013). Stochastic gradient descent for non-smooth optimization: Convergence results and optimal averaging schemes. In International conference on machine learning (pp. 71-79). PMLR.
+"""
+struct PolynomialAveraging{F} <: AbstractAverager
+    eta::F
+end
+
+PolynomialAveraging() = PolynomialAveraging(8)
+
+init(::PolynomialAveraging, x) = (x, 1)
+
+function apply(avg::PolynomialAveraging, state, x::AbstractVector{T}) where {T}
+    eta = T(avg.eta)
+    x_bar, t = state
+
+    w = (eta + 1) / (t + eta)
+    x_bar = (1 - w) * x_bar + w * x
+    return (x_bar, t + 1)
+end
+
+value(::PolynomialAveraging, state) = first(state)
diff --git a/src/optimization/rules.jl b/src/optimization/rules.jl
new file mode 100644
index 00000000..7bb65e86
--- /dev/null
+++ b/src/optimization/rules.jl
@@ -0,0 +1,89 @@
+
+"""
+    DoWG(repsilon)
+
+Distance over weighted gradient (DoWG[^KMJ2024]) optimizer.
+Its only parameter, `repsilon`, is the initial guess of the Euclidean distance to the optimum.
+
+# Parameters
+- `repsilon`: Initial guess of the Euclidean distance between the initial point and
+  the optimum. (default value: `1e-6`)
+
+[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
+"""
+Optimisers.@def struct DoWG <: Optimisers.AbstractRule
+    repsilon = 1e-6
+end
+
+Optimisers.init(o::DoWG, x::AbstractArray{T}) where {T} = (copy(x), zero(T), T(o.repsilon))
+
+function Optimisers.apply!(::DoWG, state, x::AbstractArray{T}, dx) where {T}
+    x0, v, r = state
+
+    r = max(sqrt(sum(abs2, x - x0)), r)
+    r2 = r * r
+    v = v + r2 * sum(abs2, dx)
+    η = r2 / sqrt(v)
+    dx′ = Optimisers.@lazy dx * η
+    return (x0, v, r), dx′
+end
+
+"""
+    DoG(repsilon)
+
+Distance over gradient (DoG[^IHC2023]) optimizer.
+Its only parameter, `repsilon`, is the initial guess of the Euclidean distance to the optimum.
+The original paper recommends \$ 10^{-4} ( 1 + \\lVert \\lambda_0 \\rVert ) \$, but the default value is \$ 10^{-6} \$.
+
+# Parameters
+- `repsilon`: Initial guess of the Euclidean distance between the initial point and the optimum. (default value: `1e-6`)
+
+[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.
+"""
+Optimisers.@def struct DoG <: Optimisers.AbstractRule
+    repsilon = 1e-6
+end
+
+Optimisers.init(o::DoG, x::AbstractArray{T}) where {T} = (copy(x), zero(T), T(o.repsilon))
+
+function Optimisers.apply!(::DoG, state, x::AbstractArray{T}, dx) where {T}
+    x0, v, r = state
+
+    r = max(sqrt(sum(abs2, x - x0)), r)
+    v = v + sum(abs2, dx)
+    η = r / sqrt(v)
+    dx′ = Optimisers.@lazy dx * η
+    return (x0, v, r), dx′
+end
+
+"""
+    COCOB(alpha)
+
+Continuous Coin Betting (COCOB[^OT2017]) optimizer.
+We use the "COCOB-Backprop" variant, which is closer to the Adam optimizer.
+Its only parameter is the maximum change per parameter, α, which shouldn't need much tuning.
+
+# Parameters
+- `alpha`: Scaling parameter. (default value: `100`)
+
+[^OT2017]: Orabona, F., & Tommasi, T. (2017). Training deep networks without learning rates through coin betting. Advances in Neural Information Processing Systems, 30.
+"""
+Optimisers.@def struct COCOB <: Optimisers.AbstractRule
+    alpha = 100
+end
+
+function Optimisers.init(::COCOB, x::AbstractArray{T}) where {T}
+    return (zero(x), zero(x), zero(x), zero(x), copy(x))
+end
+
+function Optimisers.apply!(o::COCOB, state, x::AbstractArray{T}, dx) where {T}
+    α = T(o.alpha)
+    L, G, R, θ, x1 = state
+
+    Optimisers.@.. L = max(L, abs(dx))
+    Optimisers.@.. G = G + abs(dx)
+    Optimisers.@.. R = max(R + (x - x1) * -dx, 0)
+    Optimisers.@.. θ = θ + -dx
+    dx′ = Optimisers.@lazy -(x1 - x) - (θ / (L * max(G + L, α * L)) * (L + R))
+    return (L, G, R, θ, x1), dx′
+end
diff --git a/src/optimize.jl b/src/optimize.jl
index 5bef7eec..99ff3a89 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -16,6 +16,7 @@ This requires the variational approximation to be marked as a functor through `Functors.@functor`.
 # Keyword Arguments
 - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
 - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
+- `averager::AbstractAverager`: Parameter averaging strategy. (Default: `NoAveraging()`)
 - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
 - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
 - `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
 - `state_init::NamedTuple`: Initial value for the internal state of optimization.
   Used to warm-start from the state of a previous run. (See the returned values below.)
 
 # Returns
-- `params`: Variational parameters optimizing the variational objective.
+- `averaged_params`: Variational parameters generated by the algorithm averaged according to `averager`.
+- `params`: Last variational parameters generated by the algorithm.
 - `stats`: Statistics gathered during optimization.
 - `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
 
 # Callback
 The callback function `callback` has a signature of
 
     callback(; stat, state, params, params_average, restructure, gradient)
 
 The arguments are as follows:
 - `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`.
 - `state`: Collection of the internal states used for optimization.
 - `params`: Variational parameters.
+- `params_average`: Variational parameters computed by the averaging strategy.
 - `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the variational approximation.
 - `gradient`: The estimated (possibly stochastic) gradient.
@@ -53,6 +56,7 @@ function optimize(
     objargs...;
     adtype::ADTypes.AbstractADType,
     optimizer::Optimisers.AbstractRule=Optimisers.Adam(),
+    averager::AbstractAverager=NoAveraging(),
     show_progress::Bool=true,
     state_init::NamedTuple=NamedTuple(),
     callback=nothing,
@@ -63,6 +67,7 @@
     params, restructure = Optimisers.destructure(deepcopy(q_init))
     opt_st = maybe_init_optimizer(state_init, optimizer, params)
     obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
+    avg_st = maybe_init_averager(state_init, averager, params)
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
     stats = NamedTuple[]
@@ -86,14 +91,17 @@
         opt_st, params = update_variational_params!(
             typeof(q_init), opt_st, params, restructure, grad
         )
+        avg_st = apply(averager, avg_st, params)
 
         if !isnothing(callback)
+            params_average = value(averager, avg_st)
             stat′ = callback(;
                 stat,
                 restructure,
                 params=params,
+                params_average=params_average,
                 gradient=grad,
-                state=(optimizer=opt_st, objective=obj_st),
+                state=(optimizer=opt_st, averager=avg_st, objective=obj_st),
             )
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
@@ -103,9 +111,10 @@
         pm_next!(prog, stat)
         push!(stats, stat)
     end
-    state = (optimizer=opt_st, objective=obj_st)
+    state = (optimizer=opt_st, averager=avg_st, objective=obj_st)
     stats = map(identity, stats)
-    return restructure(params), stats, state
+    averaged_params = value(averager, avg_st)
+    return restructure(averaged_params), restructure(params), stats, state
 end
 
 function optimize(
diff --git a/src/utils.jl b/src/utils.jl
index c504513d..11618677 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -13,6 +13,14 @@ function maybe_init_optimizer(
     end
 end
 
+function maybe_init_averager(state_init::NamedTuple, averager::AbstractAverager, params)
+    if haskey(state_init, :averager)
+        state_init.averager
+    else
+        init(averager, params)
+    end
+end
+
 function maybe_init_objective(
     state_init::NamedTuple,
     rng::Random.AbstractRNG,
diff --git a/test/Project.toml b/test/Project.toml
index 16370aee..d7212699 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -18,6 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
index 646b70b8..0ca2223f 100644
--- a/test/inference/repgradelbo_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -37,7 +37,7 @@
         @testset "convergence" begin
             Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng,
                 model,
                 objective,
@@ -48,8 +48,8 @@
                 adtype=adtype,
             )
 
-            μ = mean(q)
-            L = sqrt(cov(q))
+            μ = mean(q_avg)
+            L = sqrt(cov(q_avg))
             Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
             @test Δλ ≤ contraction_rate^(T / 2) * Δλ0
@@ -59,7 +59,7 @@
         @testset "determinism" begin
             rng = StableRNG(seed)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng,
                 model,
                 objective,
@@ -69,11 +69,11 @@
                 show_progress=PROGRESS,
                 adtype=adtype,
             )
-            μ = mean(q)
-            L = sqrt(cov(q))
+            μ = mean(q_avg)
+            L = sqrt(cov(q_avg))
 
             rng_repl = StableRNG(seed)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng_repl,
                 model,
                 objective,
@@ -83,8 +83,8 @@
                 show_progress=PROGRESS,
                 adtype=adtype,
            )
-            μ_repl = mean(q)
-            L_repl = sqrt(cov(q))
+            μ_repl = mean(q_avg)
+            L_repl = sqrt(cov(q_avg))
             @test μ == μ_repl
             @test L == L_repl
         end
diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl
index d2f5f0d7..56ddc7b5 100644
--- a/test/inference/repgradelbo_locationscale.jl
+++ b/test/inference/repgradelbo_locationscale.jl
@@ -41,7 +41,7 @@
         @testset "convergence" begin
             Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng,
                 model,
                 objective,
@@ -52,8 +52,8 @@
                 adtype=adtype,
             )
 
-            μ = q.location
-            L = q.scale
+            μ = q_avg.location
+            L = q_avg.scale
             Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
             @test Δλ ≤ contraction_rate^(T / 2) * Δλ0
@@ -63,7 +63,7 @@
         @testset "determinism" begin
             rng = StableRNG(seed)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng,
                 model,
                 objective,
@@ -73,11 +73,11 @@
                 show_progress=PROGRESS,
                 adtype=adtype,
             )
-            μ = q.location
-            L = q.scale
+            μ = q_avg.location
+            L = q_avg.scale
 
             rng_repl = StableRNG(seed)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng_repl,
                 model,
                 objective,
@@ -87,8 +87,8 @@
                 show_progress=PROGRESS,
                 adtype=adtype,
             )
-            μ_repl = q.location
-            L_repl = q.scale
+            μ_repl = q_avg.location
+            L_repl = q_avg.scale
             @test μ == μ_repl
             @test L == L_repl
         end
diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl
index 284dd2f8..245c1544 100644
--- a/test/inference/repgradelbo_locationscale_bijectors.jl
+++ b/test/inference/repgradelbo_locationscale_bijectors.jl
@@ -47,7 +47,7 @@
         @testset "convergence" begin
             Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng,
                 model,
                 objective,
@@ -58,8 +58,8 @@
                 adtype=adtype,
             )
 
-            μ = q.dist.location
-            L = q.dist.scale
+            μ = q_avg.dist.location
+            L = q_avg.dist.scale
             Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
             @test Δλ ≤ contraction_rate^(T / 2) * Δλ0
@@ -69,7 +69,7 @@
         @testset "determinism" begin
             rng = StableRNG(seed)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng,
                 model,
                 objective,
@@ -79,11 +79,11 @@
                 show_progress=PROGRESS,
                 adtype=adtype,
             )
-            μ = q.dist.location
-            L = q.dist.scale
+            μ = q_avg.dist.location
+            L = q_avg.dist.scale
 
             rng_repl = StableRNG(seed)
-            q, stats, _ = optimize(
+            q_avg, _, stats, _ = optimize(
                 rng_repl,
                 model,
                 objective,
@@ -93,8 +93,8 @@
                 show_progress=PROGRESS,
                 adtype=adtype,
             )
-            μ_repl = q.dist.location
-            L_repl = q.dist.scale
+            μ_repl = q_avg.dist.location
+            L_repl = q_avg.dist.scale
             @test μ == μ_repl
             @test L == L_repl
         end
diff --git a/test/interface/averaging.jl b/test/interface/averaging.jl
new file mode 100644
index 00000000..e7a23e5e
--- /dev/null
+++ b/test/interface/averaging.jl
@@ -0,0 +1,38 @@
+
+function simulate_sequence_average(realtype::Type{<:Real}, avg::AdvancedVI.AbstractAverager)
+    d = 3
+    n = 10
+    xs = randn(realtype, d, n)
+    xs_it = eachcol(xs)
+    st = AdvancedVI.init(avg, first(xs_it))
+    for x in xs_it
+        st = AdvancedVI.apply(avg, st, x)
+    end
+    return AdvancedVI.value(avg, st), xs
+end
+
+@testset "averaging" begin
+    avg = NoAveraging()
+    @testset "$(avg) $(realtype)" for realtype in [Float32, Float64]
+        x_avg, xs = simulate_sequence_average(realtype, avg)
+
+        @test eltype(x_avg) == realtype
+        @test x_avg ≈ xs[:, end]
+    end
+
+    η = 1
+    avg = PolynomialAveraging(η)
+    @testset "$(avg) $(realtype)" for realtype in [Float32, Float64]
+        x_avg, xs = simulate_sequence_average(realtype, avg)
+
+        T = size(xs, 2)
+        α = map(1:T) do t
+            # Formula from the proof of Theorem 4 by Shamir & Zhang (2013)
+            (η + 1) / (t + η) * (t == T ? 1 : prod(j -> (j - 1) / (j + η), (t + 1):T))
+        end
+        x_avg_true = xs * α
+
+        @test eltype(x_avg) == realtype
+        @test x_avg ≈ x_avg_true
+    end
+end
diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
index eb006c98..268098b5 100644
--- a/test/interface/optimize.jl
+++ b/test/interface/optimize.jl
@@ -16,14 +16,10 @@
 
     adtype = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)
-
-    rng = StableRNG(seed)
-    q_ref, stats_ref, _ = optimize(
-        rng, model, obj, q0, T; optimizer, show_progress=false, adtype
-    )
+    averager = PolynomialAveraging()
 
     @testset "default_rng" begin
-        optimize(model, obj, q0, T; optimizer, show_progress=false, adtype)
+        optimize(model, obj, q0, T; optimizer, averager, show_progress=false, adtype)
     end
 
     @testset "callback" begin
         test_values = rand(T)
 
         callback(; stat, args...) = (test_value=test_values[stat.iteration],)
 
         rng = StableRNG(seed)
-        _, stats, _ = optimize(
+        _, _, stats, _ = optimize(
             rng, model, obj, q0, T; show_progress=false, adtype, callback
         )
 
         @test [stat.test_value for stat in stats] == test_values
     end
 
+    rng = StableRNG(seed)
+    q_avg_ref, q_ref, _, _ = optimize(
+        rng, model, obj, q0, T; optimizer, averager, show_progress=false, adtype
+    )
+
     @testset "warm start" begin
         rng = StableRNG(seed)
 
         T_first = div(T, 2)
         T_last = T - T_first
 
-        q_first, _, state = optimize(
-            rng, model, obj, q0, T_first; optimizer, show_progress=false, adtype
+        _, q_first, _, state = optimize(
+            rng, model, obj, q0, T_first; optimizer, averager, show_progress=false, adtype
         )
 
-        q, stats, _ = optimize(
+        q_avg, q, _, _ = optimize(
             rng,
             model,
             obj,
             q_first,
             T_last;
             optimizer,
+            averager,
             show_progress=false,
             state_init=state,
             adtype,
         )
+
         @test q == q_ref
+        @test q_avg == q_avg_ref
     end
 end
diff --git a/test/interface/rules.jl b/test/interface/rules.jl
new file mode 100644
index 00000000..39bee1d5
--- /dev/null
+++ b/test/interface/rules.jl
@@ -0,0 +1,27 @@
+
+@testset "rules" begin
+    @testset "$(rule) $(realtype)" for rule in [DoWG(), DoG(), COCOB()],
+        realtype in [Float32, Float64]
+
+        T = 10^4
+
+        d = 10
+        n = 1000
+        w = randn(realtype, d)
+        X = rand(realtype, n, d)
+        w_true = randn(realtype, d)
+        loss(x, w) = mean((x * w .- x * w_true) .^ 2)
+        l0 = loss(X, w)
+
+        opt_st = Optimisers.setup(rule, w)
+        for t in 1:T
+            i = sample(1:n)
+            xi = X[i:i, :]
+            g = ForwardDiff.gradient(Base.Fix1(loss, xi), w)
+            opt_st, w = Optimisers.update!(opt_st, w, g)
+        end
+
+        @test eltype(w) == realtype
+        @test loss(X, w) < l0 / 10
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ff1bab49..31028167 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,14 +2,16 @@
 using Test
 using Test: @testset, @test
 
+using Base.Iterators
 using Bijectors
-using Random, StableRNGs
-using Statistics
 using Distributions
-using LinearAlgebra
-using SimpleUnPack: @unpack
 using FillArrays
+using LinearAlgebra
 using PDMats
+using Random, StableRNGs
+using SimpleUnPack: @unpack
+using Statistics
+using StatsBase
 
 using Functors
 using DistributionsAD
@@ -41,6 +43,8 @@ if GROUP == "All" || GROUP == "Interface"
     include("interface/ad.jl")
     include("interface/optimize.jl")
     include("interface/repgradelbo.jl")
+    include("interface/rules.jl")
+    include("interface/averaging.jl")
     include("interface/location_scale.jl")
 end
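Finally, a minimal sketch (not part of the patch) of how the new rules and averagers compose through the raw Optimisers.jl interface, mirroring the interface tests above; the toy quadratic objective and iteration count are illustrative only:

```julia
# Sketch: a parameter-free rule plus polynomial averaging on a toy quadratic.
# `init`/`apply`/`value` are internal to AdvancedVI (not exported), hence qualified.
using AdvancedVI, Optimisers

function toy_optimize(rule, avg; T=10^4)
    θ = randn(5)                      # parameters to optimize
    opt_st = Optimisers.setup(rule, θ)
    avg_st = AdvancedVI.init(avg, θ)
    for _ in 1:T
        g = 2 .* θ                    # gradient of the toy objective sum(abs2, θ)
        opt_st, θ = Optimisers.update!(opt_st, θ, g)
        avg_st = AdvancedVI.apply(avg, avg_st, θ)
    end
    # The averaged iterate: this is what `optimize` returns first when an
    # `averager` is supplied.
    return AdvancedVI.value(avg, avg_st)
end

θ_avg = toy_optimize(DoWG(), PolynomialAveraging())
```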