Skip to content

Commit 150a35a

Browse files
authored
Backport of some features to v0.2 (#56)
* added Turing integration tests * added weakdeps * fixed Project.toml * added extensions * moved to usage of extensions + ADTypes.jl * added test toml * added Flux and Enzyme as weakdeps * added Enzyme ext * fixed accidental includes * fix requires and disa le Enzyme tests * bump patch version * fixed Requires usage * another Project.toml fix * more toml fixing * maybe now
1 parent b0c4be3 commit 150a35a

File tree

11 files changed

+306
-131
lines changed

11 files changed

+306
-131
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
name: IntegrationTest
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
merge_group:
8+
types: [checks_requested]
9+
pull_request:
10+
branches: [v0.2-backport]
11+
12+
jobs:
13+
test:
14+
name: ${{ matrix.package.repo }}
15+
runs-on: ubuntu-latest
16+
strategy:
17+
fail-fast: false
18+
matrix:
19+
package:
20+
- {user: TuringLang, repo: Turing.jl}
21+
22+
steps:
23+
- uses: actions/checkout@v2
24+
- uses: julia-actions/setup-julia@v1
25+
with:
26+
version: 1
27+
arch: x64
28+
- uses: julia-actions/julia-buildpkg@latest
29+
- name: Clone Downstream
30+
uses: actions/checkout@v2
31+
with:
32+
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
33+
path: downstream
34+
- name: Load this and run the downstream tests
35+
shell: julia --color=yes --project=downstream {0}
36+
run: |
37+
using Pkg
38+
try
39+
# force it to use this PR's version of the package
40+
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
41+
Pkg.update()
42+
Pkg.test(julia_args=["--depwarn=no"]) # resolver may fail with test time deps
43+
catch err
44+
err isa Pkg.Resolve.ResolverError || rethrow()
45+
# If we can't resolve that means this is incompatible by SemVer and this is fine
46+
# It means we marked this as a breaking change, so we don't need to worry about
47+
# Mistakenly introducing a breaking change, as we have intentionally made one
48+
@info "Not compatible with this release. No problem." exception=err
49+
exit(0) # Exit immediately, as a success
50+
end

Project.toml

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
name = "AdvancedVI"
22
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
3-
version = "0.2.4"
3+
version = "0.2.5"
44

55
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
67
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
8+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
79
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
810
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
911
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -16,22 +18,39 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1618
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1719
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1820

21+
[weakdeps]
22+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
23+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
24+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
25+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
27+
[extensions]
28+
AdvancedVIEnzymeExt = ["Enzyme"]
29+
AdvancedVIFluxExt = ["Flux"]
30+
AdvancedVIReverseDiffExt = ["ReverseDiff"]
31+
AdvancedVIZygoteExt = ["Zygote"]
32+
1933
[compat]
2034
Bijectors = "0.11, 0.12, 0.13"
2135
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
2236
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
2337
DocStringExtensions = "0.8, 0.9"
38+
Enzyme = "0.12"
39+
LinearAlgebra = "1.6"
2440
ForwardDiff = "0.10.3"
41+
Flux = "0.14"
2542
ProgressMeter = "1.0.0"
26-
Requires = "0.5, 1.0"
43+
Random = "1.6"
44+
Requires = "1"
45+
ReverseDiff = "1"
2746
StatsBase = "0.32, 0.33, 0.34"
2847
StatsFuns = "0.8, 0.9, 1"
2948
Tracker = "0.2.3"
49+
Zygote = "0.6"
3050
julia = "1.6"
3151

3252
[extras]
33-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
34-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
35-
36-
[targets]
37-
test = ["Pkg", "Test"]
53+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
54+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
55+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
56+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/AdvancedVIEnzymeExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
module AdvancedVIEnzymeExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using Enzyme: Enzyme
6+
else
7+
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
8+
using ..Enzyme: Enzyme
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:enzyme}) = ADTypes.AutoEnzyme()
12+
function AdvancedVI.setadbackend(::Val{:enzyme})
13+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
14+
AdvancedVI.ADBACKEND[] = :enzyme
15+
end
16+
17+
function AdvancedVI.grad!(
18+
vo,
19+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoEnzyme},
20+
q,
21+
model,
22+
θ::AbstractVector{<:Real},
23+
out::DiffResults.MutableDiffResult,
24+
args...
25+
)
26+
f(θ) =
27+
if (q isa Distributions.Distribution)
28+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
29+
else
30+
-vo(alg, q(θ), model, args...)
31+
end
32+
# Use `Enzyme.ReverseWithPrimal` once it is released:
33+
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
34+
y = f(θ)
35+
DiffResults.value!(out, y)
36+
dy = DiffResults.gradient(out)
37+
fill!(dy, 0)
38+
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
39+
return out
40+
end
41+
42+
end

ext/AdvancedVIFluxExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module AdvancedVIFluxExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI
5+
using Flux: Flux
6+
else
7+
using ..AdvancedVI: AdvancedVI
8+
using ..Flux: Flux
9+
end
10+
11+
AdvancedVI.apply!(o::Flux.Optimise.AbstractOptimiser, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
12+
13+
end

ext/AdvancedVIReverseDiffExt.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
module AdvancedVIReverseDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using ReverseDiff: ReverseDiff
6+
else
7+
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
8+
using ..ReverseDiff: ReverseDiff
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:reversediff}) = ADTypes.AutoReverseDiff()
12+
13+
function AdvancedVI.setadbackend(::Val{:reversediff})
14+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
15+
AdvancedVI.ADBACKEND[] = :reversediff
16+
end
17+
18+
tape(f, x) = ReverseDiff.GradientTape(f, x)
19+
20+
function AdvancedVI.grad!(
21+
vo,
22+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoReverseDiff},
23+
q,
24+
model,
25+
θ::AbstractVector{<:Real},
26+
out::DiffResults.MutableDiffResult,
27+
args...
28+
)
29+
f(θ) =
30+
if (q isa Distributions.Distribution)
31+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
32+
else
33+
-vo(alg, q(θ), model, args...)
34+
end
35+
tp = tape(f, θ)
36+
ReverseDiff.gradient!(out, tp, θ)
37+
return out
38+
end
39+
40+
end

ext/AdvancedVIZygoteExt.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
module AdvancedVIZygoteExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
5+
using Zygote: Zygote
6+
else
7+
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
8+
using ..Zygote: Zygote
9+
end
10+
11+
AdvancedVI.ADBackend(::Val{:zygote}) = ADTypes.AutoZygote()
12+
function AdvancedVI.setadbackend(::Val{:zygote})
13+
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
14+
AdvancedVI.ADBACKEND[] = :zygote
15+
end
16+
17+
function AdvancedVI.grad!(
18+
vo,
19+
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoZygote},
20+
q,
21+
model,
22+
θ::AbstractVector{<:Real},
23+
out::DiffResults.MutableDiffResult,
24+
args...
25+
)
26+
f(θ) =
27+
if (q isa Distributions.Distribution)
28+
-vo(alg, AdvancedVI.update(q, θ), model, args...)
29+
else
30+
-vo(alg, q(θ), model, args...)
31+
end
32+
y, back = Zygote.pullback(f, θ)
33+
dy = first(back(1.0))
34+
DiffResults.value!(out, y)
35+
DiffResults.gradient!(out, dy)
36+
return out
37+
end
38+
39+
end

0 commit comments

Comments
 (0)