Skip to content

Commit

Permalink
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into ena…
Browse files Browse the repository at this point in the history
…ble_enzyme
  • Loading branch information
Red-Portal committed Aug 10, 2024
2 parents 1061ae8 + 4207a87 commit 4978d1f
Show file tree
Hide file tree
Showing 22 changed files with 458 additions and 474 deletions.
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style="blue"
format_markdown = true
22 changes: 11 additions & 11 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function AdvancedVI.update_variational_params!(
opt_st,
params,
restructure,
grad
grad,
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
Expand All @@ -32,18 +32,18 @@ function AdvancedVI.update_variational_params!(

params, _ = Optimisers.destructure(q)

opt_st, params
return opt_st, params
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
q_stop ::Bijectors.TransformedDistribution,
rng::Random.AbstractRNG,
q::Bijectors.TransformedDistribution,
q_stop::Bijectors.TransformedDistribution,
n_samples::Int,
ent_est ::AdvancedVI.AbstractEntropyEstimator
ent_est::AdvancedVI.AbstractEntropyEstimator,
)
transform = q.transform
q_unconst = q.dist
transform = q.transform
q_unconst = q.dist
q_unconst_stop = q_stop.dist

# Draw samples and compute entropy of the uncontrained distribution
Expand All @@ -58,14 +58,14 @@ function AdvancedVI.reparam_with_entropy(
samples_and_logjac = mapreduce(
AdvancedVI.catsamples_and_acc,
Iterators.drop(unconstr_iter, 1);
init=(reshape(samples_init, (:,1)), logjac_init)
init=(reshape(samples_init, (:, 1)), logjac_init),
) do sample
with_logabsdet_jacobian(transform, sample)
end
samples = first(samples_and_logjac)
logjac = last(samples_and_logjac)/n_samples
logjac = last(samples_and_logjac) / n_samples

entropy = unconst_entropy + logjac
samples, entropy
return samples, entropy
end
end
27 changes: 10 additions & 17 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,39 @@ else
using ..AdvancedVI: ADTypes, DiffResults
end


AdvancedVI.restructure_ad_forward(
::ADTypes.AutoEnzyme, restructure, params
) = restructure(params)::typeof(restructure.model)
function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
return restructure(params)::typeof(restructure.model)
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme,
f,
x ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
Enzyme.API.runtimeActivity!(true)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
_, y = Enzyme.autodiff(
Enzyme.ReverseWithPrimal,
f,
Enzyme.Active,
Enzyme.Duplicated(x, ∇x)
Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x)
)
DiffResults.value!(out, y)
return out
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme,
::ADTypes.AutoEnzyme,
f,
x ::AbstractVector{<:Real},
x::AbstractVector{<:Real},
aux,
out ::DiffResults.MutableDiffResult
out::DiffResults.MutableDiffResult,
)
Enzyme.API.runtimeActivity!(true)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
_, y = Enzyme.autodiff(
Enzyme.ReverseWithPrimal,
f,
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, ∇x),
Enzyme.Const(aux)
Enzyme.Const(aux),
)
DiffResults.value!(out, y)
return out
Expand Down
16 changes: 8 additions & 8 deletions ext/AdvancedVIForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ end
getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad ::ADTypes.AutoForwardDiff,
ad::ADTypes.AutoForwardDiff,
f,
x ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
chunk_size = getchunksize(ad)
config = if isnothing(chunk_size)
Expand All @@ -30,13 +30,13 @@ function AdvancedVI.value_and_gradient!(
end

function AdvancedVI.value_and_gradient!(
ad ::ADTypes.AutoForwardDiff,
ad::ADTypes.AutoForwardDiff,
f,
x ::AbstractVector,
aux,
out::DiffResults.MutableDiffResult
x::AbstractVector,
aux,
out::DiffResults.MutableDiffResult,
)
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
6 changes: 3 additions & 3 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
out::DiffResults.MutableDiffResult,
)
tp = ReverseDiff.GradientTape(f, x)
ReverseDiff.gradient!(out, tp, x)
Expand All @@ -28,9 +28,9 @@ function AdvancedVI.value_and_gradient!(
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult
out::DiffResults.MutableDiffResult,
)
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
9 changes: 3 additions & 6 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ else
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoZygote,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
y, back = Zygote.pullback(f, x)
∇x = back(one(y))
Expand All @@ -31,9 +28,9 @@ function AdvancedVI.value_and_gradient!(
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult
out::DiffResults.MutableDiffResult,
)
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
29 changes: 6 additions & 23 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ Instead, the return values should be used.
"""
function update_variational_params! end

update_variational_params!(::Type, opt_st, params, restructure, grad) =
Optimisers.update!(opt_st, params, grad)
function update_variational_params!(::Type, opt_st, params, restructure, grad)
return Optimisers.update!(opt_st, params, grad)
end

# estimators
"""
Expand All @@ -105,13 +106,7 @@ This function needs to be implemented only if `obj` is stateful.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
"""
init(
::Random.AbstractRNG,
::AbstractVariationalObjective,
::Any,
::Any,
::Any,
) = nothing
init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing

"""
estimate_objective([rng,] obj, q, prob; kwargs...)
Expand All @@ -135,7 +130,6 @@ function estimate_objective end

export estimate_objective


"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
Expand Down Expand Up @@ -176,25 +170,16 @@ Estimate the entropy of `q`.
"""
function estimate_entropy end

export
RepGradELBO,
ClosedFormEntropy,
StickingTheLandingEntropy,
MonteCarloEntropy
export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy

include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")


# Variational Families
export
MvLocationScale,
MeanFieldGaussian,
FullRankGaussian
export MvLocationScale, MeanFieldGaussian, FullRankGaussian

include("families/location_scale.jl")


# Optimization Routine

function optimize end
Expand All @@ -204,7 +189,6 @@ export optimize
include("utils.jl")
include("optimize.jl")


# optional dependencies
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
using Requires
Expand All @@ -231,4 +215,3 @@ end
end

end

Loading

0 comments on commit 4978d1f

Please sign in to comment.