
Commit

Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into lowrank
Red-Portal committed Aug 10, 2024
2 parents 1d56953 + 4207a87 commit 6752c6b
Showing 27 changed files with 450 additions and 491 deletions.
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1,2 @@
style="blue"
format_markdown = true
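The new config enables the Blue style repo-wide and also formats Julia code blocks inside Markdown files. As a hedged sketch of how a contributor would apply it locally (a JuliaFormatter.jl invocation, not part of this diff):

```julia
# Applying the repo's new settings locally; `format` picks up
# .JuliaFormatter.toml (style="blue", format_markdown=true) automatically.
using JuliaFormatter

already_clean = format(".")  # formats *.jl files (and Julia blocks in *.md) in place
```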
5 changes: 1 addition & 4 deletions .github/workflows/Benchmark.yml
@@ -1,8 +1,5 @@
 name: Benchmarks
 on:
-  push:
-    branches:
-      - master
   pull_request:
     branches:
       - master
@@ -52,4 +49,4 @@ jobs:
           alert-threshold: "200%"
           fail-on-alert: true
           benchmark-data-dir-path: benchmarks
-          auto-push: ${{ github.event_name != 'pull_request' }}
+          auto-push: false
13 changes: 8 additions & 5 deletions .github/workflows/DocNav.yml
@@ -32,13 +32,16 @@ jobs:
 # Define the URL of the navbar to be used
 NAVBAR_URL="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/main/assets/scripts/TuringNavbar.html"
-# Update all HTML files in the current directory (gh-pages root)
-./insert_navbar.sh . $NAVBAR_URL
+# Define file & folder to exclude (comma-separated list), Un-Comment the below line for excluding anything!
+EXCLUDE_PATHS="benchmarks"
+# Update all HTML files in the current directory (gh-pages root), use `--exclude` only if requred!
+./insert_navbar.sh . $NAVBAR_URL --exclude "$EXCLUDE_PATHS"
 # Remove the insert_navbar.sh file
 rm insert_navbar.sh
# Check if there are any changes
if [[ -n $(git status -s) ]]; then
git add .
33 changes: 0 additions & 33 deletions .github/workflows/JuliaNightly.yml

This file was deleted.

1 change: 0 additions & 1 deletion README.md
@@ -1,7 +1,6 @@
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.org/AdvancedVI.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://turinglang.org/AdvancedVI.jl/dev/)
[![Build Status](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml?query=branch%3Amaster)
-[![JuliaNightly](https://github.com/TuringLang/AdvancedVI.jl/workflows/JuliaNightly/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions?query=workflow%3AJuliaNightly+branch%3Amaster)
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)

# AdvancedVI.jl
22 changes: 11 additions & 11 deletions ext/AdvancedVIBijectorsExt.jl
@@ -20,7 +20,7 @@ function AdvancedVI.update_variational_params!(
     opt_st,
     params,
     restructure,
-    grad
+    grad,
 )
     opt_st, params = Optimisers.update!(opt_st, params, grad)
     q = restructure(params)
@@ -32,18 +32,18 @@ function AdvancedVI.update_variational_params!(

     params, _ = Optimisers.destructure(q)

-    opt_st, params
+    return opt_st, params
 end

 function AdvancedVI.reparam_with_entropy(
-    rng      ::Random.AbstractRNG,
-    q        ::Bijectors.TransformedDistribution,
-    q_stop   ::Bijectors.TransformedDistribution,
+    rng::Random.AbstractRNG,
+    q::Bijectors.TransformedDistribution,
+    q_stop::Bijectors.TransformedDistribution,
     n_samples::Int,
-    ent_est  ::AdvancedVI.AbstractEntropyEstimator
+    ent_est::AdvancedVI.AbstractEntropyEstimator,
 )
-    transform      = q.transform
-    q_unconst      = q.dist
+    transform = q.transform
+    q_unconst = q.dist
     q_unconst_stop = q_stop.dist

     # Draw samples and compute entropy of the uncontrained distribution
@@ -58,14 +58,14 @@ function AdvancedVI.reparam_with_entropy(
     samples_and_logjac = mapreduce(
         AdvancedVI.catsamples_and_acc,
         Iterators.drop(unconstr_iter, 1);
-        init=(reshape(samples_init, (:,1)), logjac_init)
+        init=(reshape(samples_init, (:, 1)), logjac_init),
     ) do sample
         with_logabsdet_jacobian(transform, sample)
     end
     samples = first(samples_and_logjac)
-    logjac  = last(samples_and_logjac)/n_samples
+    logjac = last(samples_and_logjac) / n_samples

     entropy = unconst_entropy + logjac
-    samples, entropy
+    return samples, entropy
 end
end
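The edits to `reparam_with_entropy` above are formatting-only, but the computation they touch is the change-of-variables identity: for a base distribution q and bijector T, the entropy of the pushforward is H[T∘q] = H[q] + E_{z∼q}[log |det J_T(z)|], estimated here from the same samples the function returns. A self-contained sketch of that identity (the Normal/exp/LogNormal pairing is an illustrative assumption, not code from this diff):

```julia
using Distributions, Statistics
using ChangesOfVariables: with_logabsdet_jacobian

q  = Normal()           # base distribution in unconstrained space
zs = rand(q, 100_000)   # Monte Carlo samples from q
logjacs = map(z -> last(with_logabsdet_jacobian(exp, z)), zs)

# Entropy of the transformed distribution: base entropy + mean log-Jacobian.
h = entropy(q) + mean(logjacs)
isapprox(h, entropy(LogNormal()); atol=0.05)  # true: exp pushes Normal() onto LogNormal()
```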
8 changes: 4 additions & 4 deletions ext/AdvancedVIEnzymeExt.jl
@@ -11,15 +11,15 @@ else
     using ..AdvancedVI: ADTypes, DiffResults
 end

-# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
 function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
 ) where {T<:Real}
-    y = f(θ)
-    DiffResults.value!(out, y)
     ∇θ = DiffResults.gradient(out)
     fill!(∇θ, zero(T))
-    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
+    _, y = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)
+    )
+    DiffResults.value!(out, y)
     return out
 end

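Two substantive changes here rather than formatting: `f` is now wrapped in `Enzyme.Const` so Enzyme treats the callable itself as non-differentiated data (relevant when `f` is a closure, cf. the Enzyme issue in the removed comment), and the primal value `y` is recovered from `ReverseWithPrimal` instead of an extra forward evaluation `y = f(θ)`. A hedged standalone sketch of the same pattern with a toy objective (not package code):

```julia
using Enzyme

f(θ) = sum(abs2, θ)   # toy objective
θ  = [1.0, 2.0, 3.0]
∇θ = zero(θ)          # shadow buffer; Enzyme accumulates into it, so start at zero

# One reverse sweep returns both the gradient (written into ∇θ) and the primal value.
_, y = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)
)
y == 14.0 && ∇θ == [2.0, 4.0, 6.0]  # value and gradient (2θ) from a single sweep
```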
16 changes: 8 additions & 8 deletions ext/AdvancedVIForwardDiffExt.jl
@@ -14,10 +14,10 @@ end
 getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

 function AdvancedVI.value_and_gradient!(
-    ad ::ADTypes.AutoForwardDiff,
+    ad::ADTypes.AutoForwardDiff,
     f,
-    x  ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
+    x::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult,
 )
     chunk_size = getchunksize(ad)
     config = if isnothing(chunk_size)
@@ -30,13 +30,13 @@ function AdvancedVI.value_and_gradient!(
 end

 function AdvancedVI.value_and_gradient!(
-    ad ::ADTypes.AutoForwardDiff,
+    ad::ADTypes.AutoForwardDiff,
     f,
-    x  ::AbstractVector,
-    aux,
-    out::DiffResults.MutableDiffResult
+    x::AbstractVector,
+    aux,
+    out::DiffResults.MutableDiffResult,
 )
-    AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
 end

end
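Beyond the argument-alignment churn, the structure is worth noting: the five-argument method closes over `aux` and forwards to the four-argument method, which builds a ForwardDiff config honoring the chunk size carried in the `AutoForwardDiff` type. A hedged usage sketch of this internal (unexported) interface with a toy objective:

```julia
using ADTypes, DiffResults
import AdvancedVI

f(x, aux) = aux.scale * sum(abs2, x)  # toy objective taking an auxiliary input
x   = [1.0, 2.0]
out = DiffResults.GradientResult(x)   # preallocated value-and-gradient buffer

AdvancedVI.value_and_gradient!(ADTypes.AutoForwardDiff(; chunksize=2), f, x, (scale=3.0,), out)
DiffResults.value(out)     # 15.0
DiffResults.gradient(out)  # [6.0, 12.0]
```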
6 changes: 3 additions & 3 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -16,7 +16,7 @@ function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoReverseDiff,
     f,
     x::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
+    out::DiffResults.MutableDiffResult,
 )
     tp = ReverseDiff.GradientTape(f, x)
     ReverseDiff.gradient!(out, tp, x)
@@ -28,9 +28,9 @@ function AdvancedVI.value_and_gradient!(
     f,
     x::AbstractVector{<:Real},
     aux,
-    out::DiffResults.MutableDiffResult
+    out::DiffResults.MutableDiffResult,
 )
-    AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
 end

end
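Same cosmetic changes as the other extensions; note that this method records a fresh `GradientTape` on every call rather than caching a compiled tape. A hedged sketch of the underlying ReverseDiff pattern (toy objective, my assumptions):

```julia
using ReverseDiff, DiffResults

f(x) = sum(abs2, x)
x   = [1.0, 2.0, 3.0]
out = DiffResults.GradientResult(x)

tp = ReverseDiff.GradientTape(f, x)  # record f's operations once
ReverseDiff.gradient!(out, tp, x)    # replay the tape, filling value and gradient
DiffResults.value(out), DiffResults.gradient(out)  # (14.0, [2.0, 4.0, 6.0])
```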
9 changes: 3 additions & 6 deletions ext/AdvancedVIZygoteExt.jl
@@ -14,10 +14,7 @@ else
 end

 function AdvancedVI.value_and_gradient!(
-    ::ADTypes.AutoZygote,
-    f,
-    x::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
+    ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
 )
     y, back = Zygote.pullback(f, x)
     ∇x = back(one(y))
@@ -31,9 +28,9 @@ function AdvancedVI.value_and_gradient!(
     f,
     x::AbstractVector{<:Real},
     aux,
-    out::DiffResults.MutableDiffResult
+    out::DiffResults.MutableDiffResult,
 )
-    AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
 end

end
28 changes: 6 additions & 22 deletions src/AdvancedVI.jl
@@ -78,8 +78,9 @@ Instead, the return values should be used.
"""
function update_variational_params! end

update_variational_params!(::Type, opt_st, params, restructure, grad) =
Optimisers.update!(opt_st, params, grad)
function update_variational_params!(::Type, opt_st, params, restructure, grad)
return Optimisers.update!(opt_st, params, grad)
end
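The short-form definition becomes an explicit `function` block with `return`, part of the repo-wide Blue-style pass. For orientation, the default method is just one Optimisers.jl step on the flattened parameters; a hedged sketch of that step in isolation (toy values, not package code):

```julia
using Optimisers

params = randn(4)  # flattened variational parameters λ
grad   = randn(4)  # a stochastic estimate of the objective's gradient at λ
opt_st = Optimisers.setup(Optimisers.Adam(1e-2), params)

# What the default `update_variational_params!` reduces to:
opt_st, params = Optimisers.update!(opt_st, params, grad)
```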

# estimators
"""
@@ -106,13 +107,7 @@ This function needs to be implemented only if `obj` is stateful.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
"""
-init(
-    ::Random.AbstractRNG,
-    ::AbstractVariationalObjective,
-    ::Any,
-    ::Any,
-    ::Any,
-) = nothing
+init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing

"""
estimate_objective([rng,] obj, q, prob; kwargs...)
@@ -136,7 +131,6 @@ function estimate_objective end

export estimate_objective

-
"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
@@ -177,21 +171,13 @@ Estimate the entropy of `q`.
"""
function estimate_entropy end

-export
-    RepGradELBO,
-    ClosedFormEntropy,
-    StickingTheLandingEntropy,
-    MonteCarloEntropy
+export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy

include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")

-
# Variational Families
-export
-    MvLocationScale,
-    MeanFieldGaussian,
-    FullRankGaussian
+export MvLocationScale, MeanFieldGaussian, FullRankGaussian

include("families/location_scale.jl")

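For orientation (a hedged summary, not text from this diff): the files included above implement the reparameterization-gradient ELBO, which the exported entropy estimators split into an energy term and an entropy term,

$$\mathrm{ELBO}(\lambda) \;=\; \mathbb{E}_{z \sim q_{\lambda}}\!\left[\log \pi(z)\right] \;+\; \mathbb{H}\!\left[q_{\lambda}\right],$$

while the exported variational families are location-scale constructions: a sample is z = C u + μ with u drawn from a base distribution, where `MeanFieldGaussian` restricts the scale C to be diagonal and `FullRankGaussian` allows a dense lower-triangular C.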
@@ -211,7 +197,6 @@ export optimize
include("utils.jl")
include("optimize.jl")

-
# optional dependencies
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
using Requires
@@ -238,4 +223,3 @@ end
end

end
-
(Diffs for the remaining 16 changed files are not shown.)
