Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Enzyme and reexport ADTypes.AutoEnzyme #1887

Draft
wants to merge 72 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
29017e7
Add support for Enzyme
devmotion Sep 28, 2022
fdf0d43
Merge branch 'master' into dw/enzyme
yebai Nov 7, 2022
d7ef23e
Merge branch 'master' into dw/enzyme
yebai Nov 12, 2022
874b9e7
Merge branch 'master' into dw/enzyme
devmotion Dec 23, 2022
43ef4c4
Apply suggestions from code review
devmotion Dec 23, 2022
3e5841f
Add Enzyme to test dependencies
devmotion Dec 29, 2022
66bce4e
Test Enzyme
devmotion Dec 29, 2022
7890134
Update ad.jl
devmotion Dec 29, 2022
120e9f5
Merge branch 'master' into dw/enzyme
yebai Feb 3, 2023
f4bd1bf
Update Project.toml
yebai Feb 3, 2023
c8e01d0
Update advi.jl
yebai Feb 3, 2023
d8d7729
Merge branch 'master' into dw/enzyme
yebai Feb 16, 2023
1d1dba0
Merge branch 'master' into dw/enzyme
yebai Mar 1, 2023
946e594
Do not call `Bijectors.setadbackend`
devmotion Mar 7, 2023
edd19a4
Merge branch 'master' into dw/enzyme
yebai Apr 12, 2023
e9eedd1
Update Project.toml
devmotion Apr 13, 2023
9cfb589
Merge branch 'master' into dw/enzyme
yebai May 25, 2023
cf03624
Merge branch 'master' into dw/enzyme
yebai Jun 14, 2023
0a5c42b
Merge branch 'master' into dw/enzyme
devmotion Jun 24, 2023
8d8d031
Address comments
devmotion Jun 26, 2023
e591630
Update runtests.jl
devmotion Jun 27, 2023
568cdac
Update Project.toml
devmotion Jul 7, 2023
d00f297
Merge branch 'master' into dw/enzyme
devmotion Jul 7, 2023
6f0bf67
Update Project.toml
devmotion Jul 7, 2023
5ba7ac6
Update Project.toml
devmotion Jul 13, 2023
162755b
Test against Enzyme#main
devmotion Jul 14, 2023
ce26c3c
Merge branch 'master' into dw/enzyme
devmotion Jul 14, 2023
b35ab28
Merge branch 'master' into dw/enzyme
yebai Jul 21, 2023
e44e756
Try addr13 branch
devmotion Jul 24, 2023
1f1b114
Update runtests.jl
devmotion Jul 27, 2023
1b3fa60
Merge branch 'master' into dw/enzyme
devmotion Jul 27, 2023
aad8a1a
Merge branch 'master' into dw/enzyme
yebai Jul 30, 2023
bb795e6
Disable Gibbs tests temporarily
yebai Jul 31, 2023
1c7f20e
Update test/Project.toml
yebai Jul 31, 2023
2a40639
Merge branch 'master' into dw/enzyme
yebai Aug 8, 2023
1b87d2e
Merge branch 'master' into dw/enzyme
yebai Aug 14, 2023
dad6b97
Merge branch 'master' into dw/enzyme
yebai Sep 4, 2023
012a0cb
disable tests unrelated to enzyme + limit CI to avoid over-use of res…
torfjelde Sep 23, 2023
552b01f
Merge branch 'master' into dw/enzyme
sunxd3 Dec 12, 2023
5777344
import `AutoEnzyme`
sunxd3 Dec 12, 2023
1ffdbca
Merge branch 'master' into dw/enzyme
sunxd3 Dec 16, 2023
121df7d
Test hmc only
sunxd3 Dec 16, 2023
a164707
Update sghmc.jl
wsmoses Dec 21, 2023
97f1fb6
Update runtests.jl
wsmoses Dec 21, 2023
c7b6cf4
disable Type unstable getfield
wsmoses Jan 25, 2024
efdd8e7
use release
wsmoses Jan 25, 2024
2fdf546
Remove seemingly unnecessary definition
devmotion Jan 25, 2024
4d8cd23
Run tests on Enzyme#main again
devmotion Jan 26, 2024
47292a7
Merge branch 'master' into dw/enzyme
wsmoses Jan 27, 2024
4b00f0d
Merge branch 'master' into dw/enzyme
wsmoses Feb 10, 2024
889275e
Merge branch 'master' into dw/enzyme
yebai Feb 27, 2024
578967b
Merge branch 'master' into dw/enzyme
yebai Mar 4, 2024
b8296be
Test with cholesky fixes
devmotion Mar 13, 2024
0385250
Merge branch 'master' into dw/enzyme
yebai Apr 8, 2024
24cc3a9
Merge branch 'master' into dw/enzyme
yebai May 29, 2024
2b54d69
Update Project.toml
yebai May 29, 2024
2823a41
Update Turing.jl
yebai May 29, 2024
f4c72bd
Merge remote-tracking branch 'origin/master' into dw/enzyme
mhauru Jun 21, 2024
6b7159c
Merge branch 'master' into dw/enzyme
devmotion Jul 1, 2024
0c376a6
Attempt at fix for `bnn` tests as outlined in #2277
torfjelde Jul 1, 2024
76b5e48
Update test/runtests.jl
yebai Jul 9, 2024
784b8cb
Update runtests.jl
yebai Jul 9, 2024
5bfd06d
remove implicit usage of `hvcat`
torfjelde Jul 9, 2024
836e29b
Merge branch 'master' into dw/enzyme
yebai Jul 9, 2024
e299042
Merge branch 'master' into dw/enzyme
devmotion Jul 10, 2024
d4d55d6
Merge branch 'master' into dw/enzyme
wsmoses Jul 21, 2024
ce13e03
Re-activate CIs disabled for Enzyme testing
torfjelde Jul 25, 2024
19a3332
Merge branch 'master' into dw/enzyme
yebai Jul 31, 2024
e2c0693
Re-enable tests with other AD backends
devmotion Aug 15, 2024
387018d
Merge branch 'master' into dw/enzyme
devmotion Aug 15, 2024
2115d52
Load `@test_broken`
devmotion Aug 15, 2024
b7ad9db
Merge branch 'master' into dw/enzyme
yebai Sep 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ jobs:
- "mcmc/ess.jl"
- "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl"
version:
- '1.7'
#- '1.7' TODO(mhauru): Temporarily disabled for Enzyme
- '1'
os:
- ubuntu-latest
- windows-latest
- macOS-latest
#- windows-latest TODO(mhauru): Temporarily disabled for Enzyme
#- macOS-latest TODO(mhauru): Temporarily disabled for Enzyme
arch:
- x64
- x86
#- x86 TODO(mhauru): Temporarily disabled for Enzyme
num_threads:
- 1
- 2
Expand Down
1 change: 1 addition & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export @model, # modelling
AutoForwardDiff, # ADTypes
AutoReverseDiff,
AutoZygote,
AutoEnzyme,
AutoTracker,
setprogress!, # debugging
Flat,
Expand Down
4 changes: 3 additions & 1 deletion src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ using Bijectors: PDMatDistribution
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote
using ADTypes:
ADTypes, AutoForwardDiff, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoZygote

using AdvancedPS: AdvancedPS

include("container.jl")

export @model,
@varname,
AutoEnzyme,
AutoForwardDiff,
AutoTracker,
AutoZygote,
Expand Down
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Expand Down Expand Up @@ -45,6 +46,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
Enzyme = "0.12"
DynamicPPL = "0.28"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
Expand All @@ -70,4 +72,3 @@ StatsFuns = "0.9.5, 1"
TimerOutputs = "0.5"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
julia = "1.3"
12 changes: 11 additions & 1 deletion test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Distributions: Bernoulli, Beta, InverseGamma, Normal
using Distributions: sample
import DynamicPPL
using DynamicPPL: Sampler, getlogp
import Enzyme
import ForwardDiff
using LinearAlgebra: I
import MCMCChains
Expand All @@ -14,7 +15,13 @@ import ReverseDiff
using Test: @test, @test_throws, @testset
using Turing

@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
Enzyme.API.typeWarning!(false)

# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoEnzyme(),)
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
Expand Down Expand Up @@ -367,6 +374,8 @@ using Turing
alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s))
chn = sample(gdemo_default, alg, 1000)
end
# Type unstable getfield of tuple not supported in Enzyme yet
if adbackend != AutoEnzyme()
@testset "vectorization @." begin
# https://github.com/FluxML/Tracker.jl/issues/119
@model function vdemo1(x)
Expand Down Expand Up @@ -549,6 +558,7 @@ using Turing
vdemo3kw(; T) = vdemo3(T)
sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250)
end
end

@testset "names_values" begin
ks, xs = Turing.Inference.names_values([
Expand Down
14 changes: 11 additions & 3 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ..NumericalTests: check_gdemo, check_numerical
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
import DynamicPPL
using DynamicPPL: Sampler
import Enzyme
import ForwardDiff
using HypothesisTests: ApproximateTwoSampleKSTest, pvalue
import ReverseDiff
Expand All @@ -16,7 +17,14 @@ using StatsFuns: logistic
using Test: @test, @test_logs, @testset
using Turing

@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
# Disable Enzyme warnings
Enzyme.API.typeWarning!(false)

# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoEnzyme(),)
# Set a seed
rng = StableRNG(123)
@testset "constrained bounded" begin
Expand Down Expand Up @@ -103,7 +111,7 @@ using Turing
alpha = 0.16 # regularizatin term
var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior

@model function bnn(ts)
@model function bnn(ts, var_prior)
b1 ~ MvNormal([0. ;0.; 0.],
[var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior])
w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
Expand All @@ -121,7 +129,7 @@ using Turing
end

# Sampling
chain = sample(rng, bnn(ts), HMC(0.1, 5; adtype=adbackend), 10)
chain = sample(rng, bnn(ts, var_prior), HMC(0.1, 5; adtype=adbackend), 10)
end

@testset "hmcda inference" begin
Expand Down
13 changes: 11 additions & 2 deletions test/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@ module SGHMCTests
using ..Models: gdemo_default
using ..NumericalTests: check_gdemo
using Distributions: sample
import Enzyme
import ForwardDiff
using LinearAlgebra: dot
import ReverseDiff
using StableRNGs: StableRNG
using Test: @test, @testset
using Turing

@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
# Disable Enzyme warnings
Enzyme.API.typeWarning!(false)

# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),)
@testset "sghmc constructor" begin
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
@test alg isa SGHMC
Expand All @@ -36,7 +44,8 @@ using Turing
end
end

@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoEnzyme(),)
@testset "sgld constructor" begin
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
@test alg isa SGLD
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import Pkg
Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git", rev="fix-cholesky"))
yebai marked this conversation as resolved.
Show resolved Hide resolved

include("test_utils/SelectiveTests.jl")
using .SelectiveTests: isincluded, parse_args
using Pkg
Expand Down
Loading