Skip to content

Commit

Permalink
Add benchmarks (#104)
Browse files Browse the repository at this point in the history
* add subdirectories to compathelper
* add benchmarks
* add compat bound to benchmark
* remove unused code
* refactor benchmark code, run formatter
* disable enzyme, mooncake for now
* update benchmark README
  • Loading branch information
Red-Portal authored Oct 5, 2024
1 parent 4da2f5b commit 65d93ae
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 84 deletions.
22 changes: 22 additions & 0 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,35 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
BenchmarkTools = "1"
Bijectors = "0.13"
Distributions = "0.25.111"
DistributionsAD = "0.6"
Enzyme = "0.13.7"
FillArrays = "1"
ForwardDiff = "0.10"
InteractiveUtils = "1"
LogDensityProblems = "2"
Mooncake = "0.4.5"
Optimisers = "0.3"
Random = "1"
ReverseDiff = "1"
SimpleUnPack = "1"
StableRNGs = "1"
Zygote = "0.6"
julia = "1.10"
10 changes: 10 additions & 0 deletions bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@
This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`.
The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main).
The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl).

To run the benchmarks locally, follow the following steps:

```julia
using Pkg
Pkg.activate(".")
Pkg.instantiate()
Pkg.develop("AdvancedVI")
include("benchmarks.jl")
```
87 changes: 59 additions & 28 deletions bench/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@

using ADTypes, ForwardDiff, ReverseDiff, Zygote
using ADTypes
using AdvancedVI
using BenchmarkTools
using Bijectors
using Distributions
using DistributionsAD
using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake
using FillArrays
using InteractiveUtils
using LinearAlgebra
Expand All @@ -17,37 +18,67 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))
@info sprint(versioninfo)
@info "BLAS threads: $(BLAS.get_num_threads())"

include("utils.jl")
include("normallognormal.jl")
include("unconstrdist.jl")

const SUITES = BenchmarkGroup()

# Comment until https://github.com/TuringLang/Bijectors.jl/pull/315 is merged
# SUITES["normal + bijector"]["meanfield"]["Zygote"] =
# @benchmarkable normallognormal(
# ;
# fptype = Float64,
# adtype = AutoZygote(),
# family = :meanfield,
# objective = :RepGradELBO,
# n_montecarlo = 4,
# )

SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoReverseDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)

SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoForwardDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol)
if family == :meanfield
MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims)))
else
FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims))
end
end

begin
T = Float64

for (probname, prob) in [
("normal + bijector", normallognormal(; n_dims=10, realtype=T))
("normal", normal(; n_dims=10, realtype=T))
]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
("RepGradELBO + STL", RepGradELBO(10; entropy=StickingTheLandingEntropy())),
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ForwardDiff", AutoForwardDiff()),
("ReverseDiff", AutoReverseDiff()),
#("Mooncake", AutoMooncake(; config=Mooncake.Config())),
#("Enzyme", AutoEnzyme()),
],
(familyname, family) in [
("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))),
(
"fullrank",
FullRankGaussian(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d))),
),
]

b = Bijectors.bijector(prob)
binv = inverse(b)
q = Bijectors.TransformedDistribution(family, binv)

SUITES[probname][objname][familyname][adname] = begin
@benchmarkable AdvancedVI.optimize(
$prob,
$obj,
$q,
$max_iter;
adtype=$adtype,
optimizer=$optimizer,
show_progress=false,
)
end
end
end
end

BenchmarkTools.tune!(SUITES; verbose=true)
results = BenchmarkTools.run(SUITES; verbose=true)
Expand Down
32 changes: 6 additions & 26 deletions bench/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,10 @@ function Bijectors.bijector(model::NormalLogNormal)
)
end

function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...)
n_dims = 10
μ_x = fptype(5.0)
σ_x = fptype(0.3)
μ_y = Fill(fptype(5.0), n_dims)
σ_y = Fill(fptype(0.3), n_dims)
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))

obj = variational_objective(objective; kwargs...)

d = LogDensityProblems.dimension(model)
q = variational_standard_mvnormal(fptype, d, family)

b = Bijectors.bijector(model)
binv = inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)

return AdvancedVI.optimize(
model,
obj,
q_transformed,
max_iter;
adtype,
optimizer=Optimisers.Adam(fptype(1e-3)),
show_progress=false,
)
function normallognormal(; n_dims=10, realtype=Float64)
μ_x = realtype(5.0)
σ_x = realtype(0.3)
μ_y = Fill(realtype(5.0), n_dims)
σ_y = Fill(realtype(0.3), n_dims)
return model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))
end
26 changes: 26 additions & 0 deletions bench/unconstrdist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

struct UnconstrDist{D<:ContinuousMultivariateDistribution}
dist::D
end

function LogDensityProblems.logdensity(model::UnconstrDist, x)
return logpdf(model.dist, x)
end

function LogDensityProblems.dimension(model::UnconstrDist)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:UnconstrDist})
return LogDensityProblems.LogDensityOrder{0}()
end

function Bijectors.bijector(model::UnconstrDist)
return identity
end

function normal(; n_dims=10, realtype=Float64)
μ = fill(realtype(5), n_dims)
Σ = Diagonal(ones(realtype, n_dims))
return UnconstrDist(MvNormal(μ, Σ))
end
20 changes: 0 additions & 20 deletions bench/utils.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
@test cov(z_samples; dims=2) cov(q_true) rtol = realtype(1e-2)

samples_ref = rand(StableRNG(1), q, n_montecarlo)
@test samples_ref == rand(StableRNG(1), q, n_montecarlo)
@test samples_ref rand(StableRNG(1), q, n_montecarlo)
end

@testset "rand! AbstractVector" begin
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_distributionsad = Dict(
)

if @isdefined(Mooncake)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_locationscale = Dict(
)

if @isdefined(Mooncake)
AD_locationscale[:Mooncake] = AutoMooncake(; config=nothing)
AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_locationscale_bijectors = Dict(
)

if @isdefined(Mooncake)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
4 changes: 2 additions & 2 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ AD_scoregradelbo_distributionsad = Dict(
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
AD_scoregradelbo_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

#if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_scoregradelbo_locationscale = Dict(
)

if @isdefined(Mooncake)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=nothing)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
4 changes: 2 additions & 2 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ const interface_ad_backends = Dict(
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote()
]
if @isdefined(Mooncake)
push!(ad_backends, AutoMooncake(; config=nothing))
push!(ad_backends, AutoMooncake(; config=Mooncake.Config()))
end
if @isdefined(Enzyme)
push!(
Expand Down

0 comments on commit 65d93ae

Please sign in to comment.