DynamicPPL -> 0.29; Julia -> 1.10; Tapir -> Mooncake #2341

Draft: wants to merge 26 commits into base master

Changes from 19 commits

Commits
b36c8b2  CompatHelper: bump compat for DynamicPPL to 0.29, (keep existing comp… (github-actions[bot], Sep 25, 2024)
3568649  CompatHelper: bump compat for DynamicPPL to 0.29 for package test, (k… (github-actions[bot], Sep 25, 2024)
f43e57d  Replace vectorize(d, r) -> DynamicPPL.tovec(r) (penelopeysm, Sep 25, 2024)
c8e2337  Require Julia >= 1.9 (penelopeysm, Sep 26, 2024)
97d6869  Add Julia 1.9 tests back (penelopeysm, Oct 2, 2024)
aac4628  Fix essential/ad tests (penelopeysm, Oct 2, 2024)
36e4651  Fix reconstruct calls in MH (penelopeysm, Oct 3, 2024)
2995791  Update src/mcmc/mh.jl (penelopeysm, Oct 3, 2024)
8387f50  Require Julia 1.10 (penelopeysm, Oct 3, 2024)
410d98e  Change 1 -> I in MvNormal() (penelopeysm, Oct 3, 2024)
ed17b7e  Simplify tests as we no longer support Julia <1.8 (penelopeysm, Oct 3, 2024)
97ed363  Simplify `set_namedtuple!` (penelopeysm, Oct 3, 2024)
b9b68c4  Remove conditional loading/exporting of Tapir (penelopeysm, Oct 3, 2024)
f42d3d8  Tapir -> Mooncake (penelopeysm, Oct 3, 2024)
7b22570  Update src/essential/Essential.jl (penelopeysm, Oct 3, 2024)
5e56f09  Remove Requires from Project.toml (penelopeysm, Oct 3, 2024)
d6d0d21  Bump minor version instead (penelopeysm, Oct 3, 2024)
e9e20dc  Restrict ADTypes to 1.9.0 for AutoMooncake() (penelopeysm, Oct 3, 2024)
452d0d0  Re-enable Mooncake tests in mcmc/abstractmcmc (penelopeysm, Oct 3, 2024)
b0bb31e  Update the currently buggy and incorrect tilde overloads in `mh.jl` (… (torfjelde, Oct 7, 2024)
2e178d7  More autoformatting (#2359) (mhauru, Oct 4, 2024)
517e9fb  Fix bad merge (penelopeysm, Oct 7, 2024)
188bd80  Restrict Mooncake to >= 0.4.9 (penelopeysm, Oct 7, 2024)
a277c29  Add MH test for LKJCholesky (penelopeysm, Oct 7, 2024)
0156002  Merge branch 'master' into ch (penelopeysm, Oct 7, 2024)
742364e  Remove Tracker (penelopeysm, Oct 10, 2024)
46 changes: 28 additions & 18 deletions .github/workflows/Tests.yml
@@ -8,25 +8,35 @@ on:

jobs:
test:
# Use matrix.test.name here to avoid it taking up the entire window width
name: test ${{matrix.test.name}} (${{ matrix.os }}, ${{ matrix.version }}, ${{ matrix.arch }}, ${{ matrix.num_threads }})
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}

strategy:
fail-fast: false
matrix:
test-args:
test:
# Run some of the slower test files individually. The last one catches everything
# not included in the others.
- "essential/ad.jl"
- "mcmc/gibbs.jl"
- "mcmc/hmc.jl"
- "mcmc/abstractmcmc.jl"
- "mcmc/Inference.jl"
- "experimental/gibbs.jl"
- "mcmc/ess.jl"
- "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl"
- name: "essential/ad"
args: "essential/ad.jl"
- name: "mcmc/gibbs"
args: "mcmc/gibbs.jl"
- name: "mcmc/hmc"
args: "mcmc/hmc.jl"
- name: "mcmc/abstractmcmc"
args: "mcmc/abstractmcmc.jl"
- name: "mcmc/Inference"
args: "mcmc/Inference.jl"
- name: "experimental/gibbs"
args: "experimental/gibbs.jl"
- name: "mcmc/ess"
args: "mcmc/ess.jl"
- name: "everything else"
args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl"
version:
- '1.7'
- '1.10'
- '1'
os:
- ubuntu-latest
@@ -39,7 +49,7 @@ jobs:
- 1
- 2
exclude:
# With Windows and macOS, only run Julia 1.7, x64, 2 threads. We just want to see
# With Windows and macOS, only run x64, 2 threads. We just want to see
# some combination work on OSes other than Ubuntu.
- os: windows-latest
version: '1'
@@ -53,11 +63,11 @@
num_threads: 1
- os: macOS-latest
num_threads: 1
# It's sufficient to test x86 with one version of Julia and one thread.
- version: '1'
arch: x86
- num_threads: 2
arch: x86
# It's sufficient to test x86 with only Julia 1.10 and 1 thread.
- arch: x86
version: '1'
- arch: x86
num_threads: 2

steps:
- name: Print matrix variables
@@ -66,7 +76,7 @@
echo "Architecture: ${{ matrix.arch }}"
echo "Julia version: ${{ matrix.version }}"
echo "Number of threads: ${{ matrix.num_threads }}"
echo "Test arguments: ${{ matrix.test-args }}"
echo "Test arguments: ${{ matrix.test.args }}"
- name: (De)activate coverage analysis
run: echo "COVERAGE=${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 2 }}" >> "$GITHUB_ENV"
shell: bash
@@ -81,7 +91,7 @@ jobs:
# Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs
# Ref https://github.com/julia-actions/julia-runtest/pull/73
- name: Call Pkg.test
run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test-args }}
run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test.args }}
- uses: julia-actions/julia-processcoverage@v1
if: ${{ env.COVERAGE }}
- uses: codecov/codecov-action@v4
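The reworked matrix above gives each test group a short `name` and passes `matrix.test.args` through to `Pkg.test` as `test_args`. A rough local equivalent, assuming Turing's test harness treats these arguments as file filters (which the `--skip` entry suggests):

```julia
import Pkg

# Run only the slow HMC test file, like the "mcmc/hmc" CI job.
Pkg.test("Turing"; test_args=["mcmc/hmc.jl"])

# Run everything except the individually scheduled files, like the
# "everything else" job (file list abbreviated here).
Pkg.test("Turing"; test_args=["--skip", "essential/ad.jl", "mcmc/gibbs.jl"])
```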
18 changes: 8 additions & 10 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.34.1"
version = "0.35.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -32,7 +32,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -49,22 +48,22 @@ TuringDynamicHMCExt = "DynamicHMC"
TuringOptimExt = "Optim"

[compat]
ADTypes = "0.2, 1"
ADTypes = "1.9"
AbstractMCMC = "5.2"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.6.0"
AdvancedVI = "0.2"
BangBang = "0.4"
BangBang = "0.4.2"
Bijectors = "0.13.6"
Compat = "4.15.0"
DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.28.2"
Compat = "4.15.0"
DynamicPPL = "0.29"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
@@ -73,21 +72,20 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "5, 6"
NamedArrays = "0.9, 0.10"
Optim = "1"
Optimization = "3"
OptimizationOptimJL = "0.1, 0.2, 0.3"
OrderedCollections = "1"
Printf = "1"
Random = "1"
Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.92.1, 2"
SciMLBase = "2"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
Statistics = "1.6"
StatsAPI = "1.6"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
julia = "1.7"
julia = "1.10"

[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
21 changes: 1 addition & 20 deletions src/Turing.jl
@@ -107,6 +107,7 @@ export @model, # modelling
AutoReverseDiff,
AutoZygote,
AutoTracker,
AutoMooncake,
setprogress!, # debugging
Flat,
FlatPos,
@@ -136,24 +137,4 @@ export @model, # modelling
MAP,
MLE

# AutoTapir is only supported by ADTypes v1.0 and above.
@static if VERSION >= v"1.10" && pkgversion(ADTypes) >= v"1"
export AutoTapir
end

if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include(
"../ext/TuringOptimExt.jl"
)
@require DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" include(
"../ext/TuringDynamicHMCExt.jl"
)
end
end

end
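With the conditional `AutoTapir` export and the `Requires`-based `__init__` gone, `AutoMooncake` is exported unconditionally. A minimal sketch of selecting the new backend, assuming the samplers accept ADTypes backends via the `adtype` keyword and that `AutoMooncake` takes a `config` keyword (`nothing` meaning the default configuration):

```julia
using Turing
import Mooncake  # the AD package formerly known as Tapir; assumed to be needed for the backend to load

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

# Sample with Mooncake as the AD backend instead of the removed AutoTapir.
chain = sample(gdemo(1.5), NUTS(; adtype=AutoMooncake(; config=nothing)), 100)
```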
10 changes: 3 additions & 7 deletions src/essential/Essential.jl
@@ -11,7 +11,8 @@ using Bijectors: PDMatDistribution
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote
using ADTypes:
ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote, AutoMooncake

using AdvancedPS: AdvancedPS

@@ -23,13 +24,8 @@ export @model,
AutoTracker,
AutoZygote,
AutoReverseDiff,
AutoMooncake,
@logprob_str,
@prob_str

# AutoTapir is only supported by ADTypes v1.0 and above.
@static if VERSION >= v"1.10" && pkgversion(ADTypes) >= v"1"
using ADTypes: AutoTapir
export AutoTapir
end

end # module
2 changes: 1 addition & 1 deletion src/mcmc/Inference.jl
@@ -5,7 +5,7 @@ using DynamicPPL: Metadata, VarInfo, TypedVarInfo,
islinked, invlink!, link!,
setindex!!, push!!,
setlogp!!, getlogp,
VarName, getsym, vectorize,
VarName, getsym,
_getvns, getdist,
Model, Sampler, SampleFromPrior, SampleFromUniform,
DefaultContext, PriorContext,
2 changes: 1 addition & 1 deletion src/mcmc/ess.jl
@@ -85,7 +85,7 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
dist = getdist(varinfo, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("[ESS] only supports Gaussian prior distributions")
vectorize(dist, mean(dist))
DynamicPPL.tovec(mean(dist))
end
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
end
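The `vectorize(dist, x)` -> `DynamicPPL.tovec(x)` replacement here (and in `particle_mcmc.jl` below) drops the distribution argument entirely. A hedged sketch of the assumed behaviour, flattening a value to a plain `Vector` regardless of which distribution produced it:

```julia
using DynamicPPL

DynamicPPL.tovec(1.0)                 # assumed to return [1.0]
DynamicPPL.tovec([1.0 2.0; 3.0 4.0])  # assumed to return vec(...), i.e. [1.0, 3.0, 2.0, 4.0]
```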
53 changes: 15 additions & 38 deletions src/mcmc/mh.jl
@@ -212,43 +212,20 @@
Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
"""
function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple)
# TODO: Replace this with something like
# for vn in keys(vi)
# vi = DynamicPPL.setindex!!(vi, get(nt, vn))
# end
for (n, vals) in pairs(nt)
vns = vi.metadata[n].vns
nvns = length(vns)

# if there is a single variable only
if nvns == 1
# assign the unpacked values
if length(vals) == 1
vi[vns[1]] = [vals[1];]
# otherwise just assign the values
else
vi[vns[1]] = [vals;]
end
# if there are multiple variables
elseif vals isa AbstractArray
nvals = length(vals)
# if values are provided as an array with a single element
if nvals == 1
# iterate over variables and unpacked values
for (vn, val) in zip(vns, vals[1])
vi[vn] = [val;]
end
# otherwise number of variables and number of values have to be equal
elseif nvals == nvns
# iterate over variables and values
for (vn, val) in zip(vns, vals)
vi[vn] = [val;]
end
else
error("Cannot assign `NamedTuple` to `VarInfo`")
end
if vals isa AbstractVector
vals = unvectorize(vals)

end
if length(vns) == 1
# Only one variable, assign the values to it
DynamicPPL.setindex!(vi, vals, vns[1])
else
error("Cannot assign `NamedTuple` to `VarInfo`")
# Spread the values across the variables
length(vns) == length(vals) || error("Unequal number of variables and values")
for (vn, val) in zip(vns, vals)
DynamicPPL.setindex!(vi, val, vn)
end

end
end
end
@@ -285,10 +262,10 @@
unvectorize(dists::AbstractVector) = length(dists) == 1 ? first(dists) : dists

# possibly unpack and reshape samples according to the prior distribution
reconstruct(dist::Distribution, val::AbstractVector) = DynamicPPL.reconstruct(dist, val)
function reconstruct(dist::AbstractVector{<:UnivariateDistribution}, val::AbstractVector)
return val
function reconstruct(dist::Distribution, val::AbstractVector)
return DynamicPPL.from_vec_transform(dist)(val)
end
reconstruct(dist::AbstractVector{<:UnivariateDistribution}, val::AbstractVector) = val

function reconstruct(dist::AbstractVector{<:MultivariateDistribution}, val::AbstractVector)
offset = 0
return map(dist) do d
@@ -322,7 +299,7 @@
:(
$name = reconstruct(
unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)),
DynamicPPL.getval(vi, vns.$name),
DynamicPPL.getindex_internal(vi, vns.$name),
)
) for name in names
]
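The new `reconstruct` goes through `DynamicPPL.from_vec_transform(dist)`, which is assumed here to invert the flattening done by `DynamicPPL.tovec`. A minimal round-trip sketch under that assumption:

```julia
using Distributions, DynamicPPL, LinearAlgebra

d = MvNormal(zeros(2), I)
x = rand(d)
v = DynamicPPL.tovec(x)                  # flatten to a plain Vector
y = DynamicPPL.from_vec_transform(d)(v)  # map the flat vector back to the distribution's support
# y should match x up to floating-point round-off
```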
2 changes: 1 addition & 1 deletion src/mcmc/particle_mcmc.jl
@@ -380,7 +380,7 @@ function DynamicPPL.assume(
elseif is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del") # Reference particle parent
r = rand(trng, dist)
vi[vn] = vectorize(dist, r)
vi[vn] = DynamicPPL.tovec(r)
DynamicPPL.setgid!(vi, spl.selector, vn)
setorder!(vi, vn, get_num_produce(vi))
else
3 changes: 2 additions & 1 deletion test/Project.toml
@@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -45,7 +46,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.28"
DynamicPPL = "0.29"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
8 changes: 4 additions & 4 deletions test/essential/ad.jl
@@ -2,7 +2,7 @@ module AdTests

using ..Models: gdemo_default
using Distributions: logpdf
using DynamicPPL: getlogp, getval
using DynamicPPL: getlogp, getindex_internal
using ForwardDiff
using LinearAlgebra
using LogDensityProblems: LogDensityProblems
@@ -24,7 +24,7 @@ function test_model_ad(model, f, syms::Vector{Symbol})
s = syms[i]
vnms[i] = getfield(vi.metadata, s).vns[1]

vals = getval(vi, vnms[i])
vals = getindex_internal(vi, vnms[i])
for i in eachindex(vals)
push!(vnvals, vals[i])
end
@@ -61,8 +61,8 @@ end
ad_test_f(vi, SampleFromPrior())
svn = vi.metadata.s.vns[1]
mvn = vi.metadata.m.vns[1]
_s = getval(vi, svn)[1]
_m = getval(vi, mvn)[1]
_s = getindex_internal(vi, svn)[1]
_m = getindex_internal(vi, mvn)[1]

dist_s = InverseGamma(2, 3)

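`getval` is replaced by `DynamicPPL.getindex_internal`, assumed here to return the flattened internal representation of a variable stored in a `VarInfo` (the read-side counterpart of `tovec`). A hypothetical usage sketch:

```julia
using Turing, DynamicPPL

@model demo() = x ~ Normal()

vi = DynamicPPL.VarInfo(demo())
vn = @varname(x)
DynamicPPL.getindex_internal(vi, vn)  # e.g. a 1-element Vector{Float64}
```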
5 changes: 2 additions & 3 deletions test/mcmc/Inference.jl
@@ -12,11 +12,10 @@ using LinearAlgebra: I
import MCMCChains
import Random
import ReverseDiff
import Mooncake
using Test: @test, @test_throws, @testset
using Turing

ADUtils.install_tapir && import Tapir

@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
# Only test threading if 1.3+.
if VERSION > v"1.2"
@@ -578,7 +577,7 @@ ADUtils.install_tapir && import Tapir
); true)

@model function demo_incorrect_missing(y)
y[1:1] ~ MvNormal(zeros(1), 1)
y[1:1] ~ MvNormal(zeros(1), I)
end
@test_throws ErrorException sample(
demo_incorrect_missing([missing]), NUTS(), 1000; check_model=true
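The `MvNormal(zeros(1), 1)` -> `MvNormal(zeros(1), I)` change avoids the deprecated scalar second argument; an explicit (optionally scaled) `LinearAlgebra.I` is the supported way to spell an isotropic covariance. For example:

```julia
using Distributions, LinearAlgebra

MvNormal(zeros(1), I)        # identity covariance
MvNormal(zeros(3), 4.0 * I)  # isotropic covariance with variance 4
```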