Skip to content

Commit cdbde76

Browse files
committed
Add Enzyme to AD tests
1 parent 7060896 commit cdbde76

File tree

3 files changed

+54
-49
lines changed

3 files changed

+54
-49
lines changed

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
1313
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
14+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1415
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1516
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1617
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"

test/ad.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using DynamicPPL: LogDensityFunction
2+
using Enzyme: Enzyme
3+
using EnzymeCore: set_runtime_activity, Forward, Reverse
24

3-
@testset "Automatic differentiation" begin
5+
@testset verbose = true "Automatic differentiation" begin
46
@testset "Unsupported backends" begin
57
@model demo() = x ~ Normal()
68
@test_logs (:warn, r"not officially supported") LogDensityFunction(
@@ -23,9 +25,11 @@ using DynamicPPL: LogDensityFunction
2325
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
2426

2527
@testset "$adtype" for adtype in [
26-
AutoReverseDiff(; compile=false),
27-
AutoReverseDiff(; compile=true),
28-
AutoMooncake(; config=nothing),
28+
AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
29+
AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
30+
# AutoReverseDiff(; compile=false),
31+
# AutoReverseDiff(; compile=true),
32+
# AutoMooncake(; config=nothing),
2933
]
3034
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
3135

test/runtests.jl

+45-45
Original file line numberDiff line numberDiff line change
@@ -45,56 +45,56 @@ include("test_util.jl")
4545
# groups are chosen to make both groups take roughly the same amount of
4646
# time, but beyond that there is no particular reason for the split.
4747
if GROUP == "All" || GROUP == "Group1"
48-
include("utils.jl")
49-
include("compiler.jl")
50-
include("varnamedvector.jl")
51-
include("varinfo.jl")
52-
include("simple_varinfo.jl")
53-
include("model.jl")
54-
include("sampler.jl")
55-
include("independence.jl")
56-
include("distribution_wrappers.jl")
57-
include("logdensityfunction.jl")
58-
include("linking.jl")
59-
include("serialization.jl")
60-
include("pointwise_logdensities.jl")
61-
include("lkj.jl")
62-
include("deprecated.jl")
48+
# include("utils.jl")
49+
# include("compiler.jl")
50+
# include("varnamedvector.jl")
51+
# include("varinfo.jl")
52+
# include("simple_varinfo.jl")
53+
# include("model.jl")
54+
# include("sampler.jl")
55+
# include("independence.jl")
56+
# include("distribution_wrappers.jl")
57+
# include("logdensityfunction.jl")
58+
# include("linking.jl")
59+
# include("serialization.jl")
60+
# include("pointwise_logdensities.jl")
61+
# include("lkj.jl")
62+
# include("deprecated.jl")
6363
end
6464

6565
if GROUP == "All" || GROUP == "Group2"
66-
include("contexts.jl")
67-
include("context_implementations.jl")
68-
include("threadsafe.jl")
69-
include("debug_utils.jl")
70-
@testset "compat" begin
71-
include(joinpath("compat", "ad.jl"))
72-
end
73-
@testset "extensions" begin
74-
include("ext/DynamicPPLMCMCChainsExt.jl")
75-
include("ext/DynamicPPLJETExt.jl")
76-
end
66+
# include("contexts.jl")
67+
# include("context_implementations.jl")
68+
# include("threadsafe.jl")
69+
# include("debug_utils.jl")
70+
# @testset "compat" begin
71+
# include(joinpath("compat", "ad.jl"))
72+
# end
73+
# @testset "extensions" begin
74+
# include("ext/DynamicPPLMCMCChainsExt.jl")
75+
# include("ext/DynamicPPLJETExt.jl")
76+
# end
7777
@testset "ad" begin
78-
include("ext/DynamicPPLMooncakeExt.jl")
78+
# include("ext/DynamicPPLMooncakeExt.jl")
7979
include("ad.jl")
8080
end
81-
@testset "prob and logprob macro" begin
82-
@test_throws ErrorException prob"..."
83-
@test_throws ErrorException logprob"..."
84-
end
85-
@testset "doctests" begin
86-
DocMeta.setdocmeta!(
87-
DynamicPPL,
88-
:DocTestSetup,
89-
:(using DynamicPPL, Distributions);
90-
recursive=true,
91-
)
92-
doctestfilters = [
93-
# Ignore the source of a warning in the doctest output, since this is dependent on host.
94-
# This is a line that starts with "└ @ " and ends with the line number.
95-
r"└ @ .+:[0-9]+",
96-
]
97-
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
98-
end
81+
# @testset "prob and logprob macro" begin
82+
# @test_throws ErrorException prob"..."
83+
# @test_throws ErrorException logprob"..."
84+
# end
85+
# @testset "doctests" begin
86+
# DocMeta.setdocmeta!(
87+
# DynamicPPL,
88+
# :DocTestSetup,
89+
# :(using DynamicPPL, Distributions);
90+
# recursive=true,
91+
# )
92+
# doctestfilters = [
93+
# # Ignore the source of a warning in the doctest output, since this is dependent on host.
94+
# # This is a line that starts with "└ @ " and ends with the line number.
95+
# r"└ @ .+:[0-9]+",
96+
# ]
97+
# doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
98+
# end
9999
end
100100
end

0 commit comments

Comments
 (0)