Skip to content

Commit 10ef594

Browse files
committed
Merge remote-tracking branch 'origin/v0.2-backport' into v0.2-backport
2 parents 0b9778c + 150a35a commit 10ef594

File tree

10 files changed

+256
-131
lines changed

10 files changed

+256
-131
lines changed

Project.toml

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
name = "AdvancedVI"
22
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
3-
version = "0.2.4"
3+
version = "0.2.5"
44

55
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
67
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
8+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
79
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
810
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
911
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -16,22 +18,39 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1618
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1719
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1820

21+
[weakdeps]
22+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
23+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
24+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
25+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
27+
[extensions]
28+
AdvancedVIEnzymeExt = ["Enzyme"]
29+
AdvancedVIFluxExt = ["Flux"]
30+
AdvancedVIReverseDiffExt = ["ReverseDiff"]
31+
AdvancedVIZygoteExt = ["Zygote"]
32+
1933
[compat]
2034
Bijectors = "0.11, 0.12, 0.13"
2135
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
2236
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
2337
DocStringExtensions = "0.8, 0.9"
38+
Enzyme = "0.12"
39+
LinearAlgebra = "1.6"
2440
ForwardDiff = "0.10.3"
41+
Flux = "0.14"
2542
ProgressMeter = "1.0.0"
26-
Requires = "0.5, 1.0"
43+
Random = "1.6"
44+
Requires = "1"
45+
ReverseDiff = "1"
2746
StatsBase = "0.32, 0.33, 0.34"
2847
StatsFuns = "0.8, 0.9, 1"
2948
Tracker = "0.2.3"
49+
Zygote = "0.6"
3050
julia = "1.6"
3151

3252
[extras]
33-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
34-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
35-
36-
[targets]
37-
test = ["Pkg", "Test"]
53+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
54+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
55+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
56+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/AdvancedVIEnzymeExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
module AdvancedVIEnzymeExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using Enzyme: Enzyme
6+
else
7+
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
8+
using ..Enzyme: Enzyme
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:enzyme}) = ADTypes.AutoEnzyme()
12+
function AdvancedVI.setadbackend(::Val{:enzyme})
13+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
14+
AdvancedVI.ADBACKEND[] = :enzyme
15+
end
16+
17+
function AdvancedVI.grad!(
18+
vo,
19+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoEnzyme},
20+
q,
21+
model,
22+
θ::AbstractVector{<:Real},
23+
out::DiffResults.MutableDiffResult,
24+
args...
25+
)
26+
f(θ) =
27+
if (q isa Distributions.Distribution)
28+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
29+
else
30+
-vo(alg, q(θ), model, args...)
31+
end
32+
# Use `Enzyme.ReverseWithPrimal` once it is released:
33+
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
34+
y = f(θ)
35+
DiffResults.value!(out, y)
36+
dy = DiffResults.gradient(out)
37+
fill!(dy, 0)
38+
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
39+
return out
40+
end
41+
42+
end

ext/AdvancedVIFluxExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module AdvancedVIFluxExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI
5+
using Flux: Flux
6+
else
7+
using ..AdvancedVI: AdvancedVI
8+
using ..Flux: Flux
9+
end
10+
11+
AdvancedVI.apply!(o::Flux.Optimise.AbstractOptimiser, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
12+
13+
end

ext/AdvancedVIReverseDiffExt.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
module AdvancedVIReverseDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using ReverseDiff: ReverseDiff
6+
else
7+
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
8+
using ..ReverseDiff: ReverseDiff
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:reversediff}) = ADTypes.AutoReverseDiff()
12+
13+
function AdvancedVI.setadbackend(::Val{:reversediff})
14+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
15+
AdvancedVI.ADBACKEND[] = :reversediff
16+
end
17+
18+
tape(f, x) = ReverseDiff.GradientTape(f, x)
19+
20+
function AdvancedVI.grad!(
21+
vo,
22+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoReverseDiff},
23+
q,
24+
model,
25+
θ::AbstractVector{<:Real},
26+
out::DiffResults.MutableDiffResult,
27+
args...
28+
)
29+
f(θ) =
30+
if (q isa Distributions.Distribution)
31+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
32+
else
33+
-vo(alg, q(θ), model, args...)
34+
end
35+
tp = tape(f, θ)
36+
ReverseDiff.gradient!(out, tp, θ)
37+
return out
38+
end
39+
40+
end

ext/AdvancedVIZygoteExt.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
module AdvancedVIZygoteExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using Zygote: Zygote
6+
else
7+
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
8+
using ..Zygote: Zygote
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:zygote}) = ADTypes.AutoZygote()
12+
function AdvancedVI.setadbackend(::Val{:zygote})
13+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
14+
AdvancedVI.ADBACKEND[] = :zygote
15+
end
16+
17+
function AdvancedVI.grad!(
18+
vo,
19+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoZygote},
20+
q,
21+
model,
22+
θ::AbstractVector{<:Real},
23+
out::DiffResults.MutableDiffResult,
24+
args...
25+
)
26+
f(θ) =
27+
if (q isa Distributions.Distribution)
28+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
29+
else
30+
-vo(alg, q(θ), model, args...)
31+
end
32+
y, back = Zygote.pullback(f, θ)
33+
dy = first(back(1.0))
34+
DiffResults.value!(out, y)
35+
DiffResults.gradient!(out, dy)
36+
return out
37+
end
38+
39+
end

src/AdvancedVI.jl

Lines changed: 34 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@ using DocStringExtensions
77

88
using ProgressMeter, LinearAlgebra
99

10-
using ForwardDiff
11-
using Tracker
10+
using ADTypes: ADTypes
11+
using DiffResults: DiffResults
12+
13+
using ForwardDiff: ForwardDiff
14+
using Tracker: Tracker
1215

1316
const PROGRESS = Ref(true)
1417
function turnprogress(switch::Bool)
@@ -18,94 +21,6 @@ end
1821

1922
const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
2023

21-
include("ad.jl")
22-
include("utils.jl")
23-
24-
using Requires
25-
function __init__()
26-
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
27-
apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
28-
Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
29-
Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
30-
end
31-
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
32-
include("compat/zygote.jl")
33-
export ZygoteAD
34-
35-
function AdvancedVI.grad!(
36-
vo,
37-
alg::VariationalInference{<:AdvancedVI.ZygoteAD},
38-
q,
39-
model,
40-
θ::AbstractVector{<:Real},
41-
out::DiffResults.MutableDiffResult,
42-
args...
43-
)
44-
f(θ) = if (q isa Distribution)
45-
- vo(alg, update(q, θ), model, args...)
46-
else
47-
- vo(alg, q(θ), model, args...)
48-
end
49-
y, back = Zygote.pullback(f, θ)
50-
dy = first(back(1.0))
51-
DiffResults.value!(out, y)
52-
DiffResults.gradient!(out, dy)
53-
return out
54-
end
55-
end
56-
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
57-
include("compat/reversediff.jl")
58-
export ReverseDiffAD
59-
60-
function AdvancedVI.grad!(
61-
vo,
62-
alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
63-
q,
64-
model,
65-
θ::AbstractVector{<:Real},
66-
out::DiffResults.MutableDiffResult,
67-
args...
68-
)
69-
f(θ) = if (q isa Distribution)
70-
- vo(alg, update(q, θ), model, args...)
71-
else
72-
- vo(alg, q(θ), model, args...)
73-
end
74-
tp = AdvancedVI.tape(f, θ)
75-
ReverseDiff.gradient!(out, tp, θ)
76-
return out
77-
end
78-
end
79-
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
80-
include("compat/enzyme.jl")
81-
export EnzymeAD
82-
83-
function AdvancedVI.grad!(
84-
vo,
85-
alg::VariationalInference{<:AdvancedVI.EnzymeAD},
86-
q,
87-
model,
88-
θ::AbstractVector{<:Real},
89-
out::DiffResults.MutableDiffResult,
90-
args...
91-
)
92-
f(θ) = if (q isa Distribution)
93-
- vo(alg, update(q, θ), model, args...)
94-
else
95-
- vo(alg, q(θ), model, args...)
96-
end
97-
# Use `Enzyme.ReverseWithPrimal` once it is released:
98-
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
99-
y = f(θ)
100-
DiffResults.value!(out, y)
101-
dy = DiffResults.gradient(out)
102-
fill!(dy, 0)
103-
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
104-
return out
105-
end
106-
end
107-
end
108-
10924
export
11025
vi,
11126
ADVI,
@@ -115,10 +30,12 @@ export
11530
DecayedADAGrad,
11631
VariationalInference
11732

33+
include("utils.jl")
34+
include("ad.jl")
35+
11836
abstract type VariationalInference{AD} end
11937

120-
getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
121-
getADtype(::VariationalInference{AD}) where AD = AD
38+
getchunksize(::ADTypes.AutoForwardDiff{chunk}) where chunk = chunk === nothing ? 0 : chunk
12239

12340
abstract type VariationalObjective end
12441

@@ -129,7 +46,7 @@ const VariationalPosterior = Distribution{Multivariate, Continuous}
12946
grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)
13047
13148
Computes the gradients used in `optimize!`. Default implementation is provided for
132-
`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
49+
`VariationalInference{AD}` where `AD` is either `ADTypes.AutoForwardDiff` or `ADTypes.AutoTracker`.
13350
This implicitly also gives a default implementation of `optimize!`.
13451
13552
Variance reduction techniques, e.g. control variates, should be implemented in this function.
@@ -158,7 +75,7 @@ function update end
15875
# default implementations
15976
function grad!(
16077
vo,
161-
alg::VariationalInference{<:ForwardDiffAD},
78+
alg::VariationalInference{<:ADTypes.AutoForwardDiff},
16279
q,
16380
model,
16481
θ::AbstractVector{<:Real},
@@ -172,7 +89,7 @@ function grad!(
17289
end
17390

17491
# Set chunk size and do ForwardMode.
175-
chunk_size = getchunksize(typeof(alg))
92+
chunk_size = getchunksize(alg.adtype)
17693
config = if chunk_size == 0
17794
ForwardDiff.GradientConfig(f, θ)
17895
else
@@ -183,7 +100,7 @@ end
183100

184101
function grad!(
185102
vo,
186-
alg::VariationalInference{<:TrackerAD},
103+
alg::VariationalInference{<:ADTypes.AutoTracker},
187104
q,
188105
model,
189106
θ::AbstractVector{<:Real},
@@ -267,4 +184,25 @@ include("optimisers.jl")
267184
# VI algorithms
268185
include("advi.jl")
269186

187+
if !isdefined(Base, :get_extension)
188+
using Requires
189+
end
190+
191+
@static if !isdefined(Base, :get_extension)
192+
function __init__()
193+
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
194+
"../ext/AdvancedVIReverseDiffExt.jl"
195+
)
196+
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
197+
"../ext/AdvancedVIZygoteExt.jl"
198+
)
199+
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
200+
"../ext/AdvancedVIEnzymeExt.jl"
201+
)
202+
@require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" include(
203+
"../ext/AdvancedVIFluxExt.jl"
204+
)
205+
end
206+
end
207+
270208
end # module

0 commit comments

Comments
 (0)