Skip to content

Commit 8f8018a

Browse files
committed
Add ForwardDiffExt tests back
1 parent 11d0e8a commit 8f8018a

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

test/ext/DynamicPPLForwardDiffExt.jl

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module DynamicPPLForwardDiffExtTests
2+
3+
using DynamicPPL
4+
using ADTypes: AutoForwardDiff
5+
using ForwardDiff: ForwardDiff
6+
using Distributions: MvNormal
7+
using LinearAlgebra: I
8+
using Test: @test, @testset
9+
10+
# get_chunksize(ad::AutoForwardDiff{chunk}) where {chunk} = chunk
11+
12+
@testset "ForwardDiff tweak_adtype" begin
13+
MODEL_SIZE = 10
14+
@model f() = x ~ MvNormal(zeros(MODEL_SIZE), I)
15+
model = f()
16+
varinfo = VarInfo(model)
17+
context = DefaultContext()
18+
19+
@testset "Chunk size setting" for chunksize in (nothing, 0)
20+
base_adtype = AutoForwardDiff(; chunksize=chunksize)
21+
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
22+
@test new_adtype isa AutoForwardDiff{MODEL_SIZE}
23+
end
24+
25+
@testset "Tag setting" begin
26+
base_adtype = AutoForwardDiff()
27+
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
28+
@test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag}
29+
end
30+
end
31+
32+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ include("test_util.jl")
7575
include("ext/DynamicPPLJETExt.jl")
7676
end
7777
@testset "ad" begin
78+
include("ext/DynamicPPLForwardDiffExt.jl")
7879
include("ext/DynamicPPLMooncakeExt.jl")
7980
include("ad.jl")
8081
end

0 commit comments

Comments
 (0)