Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DynamicPPL -> 0.29; Julia -> 1.10; Tapir -> Mooncake #2341

Draft
wants to merge 26 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b36c8b2
CompatHelper: bump compat for DynamicPPL to 0.29, (keep existing comp…
github-actions[bot] Sep 25, 2024
3568649
CompatHelper: bump compat for DynamicPPL to 0.29 for package test, (k…
github-actions[bot] Sep 25, 2024
f43e57d
Replace vectorize(d, r) -> DynamicPPL.tovec(r)
penelopeysm Sep 25, 2024
c8e2337
Require Julia >= 1.9
penelopeysm Sep 26, 2024
97d6869
Add Julia 1.9 tests back
penelopeysm Oct 2, 2024
aac4628
Fix essential/ad tests
penelopeysm Oct 2, 2024
36e4651
Fix reconstruct calls in MH
penelopeysm Oct 3, 2024
2995791
Update src/mcmc/mh.jl
penelopeysm Oct 3, 2024
8387f50
Require Julia 1.10
penelopeysm Oct 3, 2024
410d98e
Change 1 -> I in MvNormal()
penelopeysm Oct 3, 2024
ed17b7e
Simplify tests as we no longer support Julia <1.8
penelopeysm Oct 3, 2024
97ed363
Simplify `set_namedtuple!`
penelopeysm Oct 3, 2024
b9b68c4
Remove conditional loading/exporting of Tapir
penelopeysm Oct 3, 2024
f42d3d8
Tapir -> Mooncake
penelopeysm Oct 3, 2024
7b22570
Update src/essential/Essential.jl
penelopeysm Oct 3, 2024
5e56f09
Remove Requires from Project.toml
penelopeysm Oct 3, 2024
d6d0d21
Bump minor version instead
penelopeysm Oct 3, 2024
e9e20dc
Restrict ADTypes to 1.9.0 for AutoMooncake()
penelopeysm Oct 3, 2024
452d0d0
Re-enable Mooncake tests in mcmc/abstractmcmc
penelopeysm Oct 3, 2024
b0bb31e
Update the currently buggy and incorrect tilde overloads in `mh.jl` (…
torfjelde Oct 7, 2024
2e178d7
More autoformatting (#2359)
mhauru Oct 4, 2024
517e9fb
Fix bad merge
penelopeysm Oct 7, 2024
188bd80
Restrict Mooncake to >= 0.4.9
penelopeysm Oct 7, 2024
a277c29
Add MH test for LKJCholesky
penelopeysm Oct 7, 2024
0156002
Merge branch 'master' into ch
penelopeysm Oct 7, 2024
742364e
Remove Tracker
penelopeysm Oct 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,35 @@ on:

jobs:
test:
# Use matrix.test.name here to avoid it taking up the entire window width
name: test ${{matrix.test.name}} (${{ matrix.os }}, ${{ matrix.version }}, ${{ matrix.arch }}, ${{ matrix.num_threads }})
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}

strategy:
fail-fast: false
matrix:
test-args:
test:
# Run some of the slower test files individually. The last one catches everything
# not included in the others.
- "essential/ad.jl"
- "mcmc/gibbs.jl"
- "mcmc/hmc.jl"
- "mcmc/abstractmcmc.jl"
- "mcmc/Inference.jl"
- "experimental/gibbs.jl"
- "mcmc/ess.jl"
- "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl"
- name: "essential/ad"
args: "essential/ad.jl"
- name: "mcmc/gibbs"
args: "mcmc/gibbs.jl"
- name: "mcmc/hmc"
args: "mcmc/hmc.jl"
- name: "mcmc/abstractmcmc"
args: "mcmc/abstractmcmc.jl"
- name: "mcmc/Inference"
args: "mcmc/Inference.jl"
- name: "experimental/gibbs"
args: "experimental/gibbs.jl"
- name: "mcmc/ess"
args: "mcmc/ess.jl"
- name: "everything else"
args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl"
version:
- '1.7'
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
- '1.10'
- '1'
os:
- ubuntu-latest
Expand All @@ -39,7 +49,7 @@ jobs:
- 1
- 2
exclude:
# With Windows and macOS, only run Julia 1.7, x64, 2 threads. We just want to see
# With Windows and macOS, only run x64, 2 threads. We just want to see
# some combination work on OSes other than Ubuntu.
- os: windows-latest
version: '1'
Expand All @@ -53,11 +63,11 @@ jobs:
num_threads: 1
- os: macOS-latest
num_threads: 1
# It's sufficient to test x86 with one version of Julia and one thread.
- version: '1'
arch: x86
- num_threads: 2
arch: x86
# It's sufficient to test x86 with only Julia 1.10 and 1 thread.
- arch: x86
version: '1'
- arch: x86
num_threads: 2

steps:
- name: Print matrix variables
Expand All @@ -66,7 +76,7 @@ jobs:
echo "Architecture: ${{ matrix.arch }}"
echo "Julia version: ${{ matrix.version }}"
echo "Number of threads: ${{ matrix.num_threads }}"
echo "Test arguments: ${{ matrix.test-args }}"
echo "Test arguments: ${{ matrix.test.args }}"
- name: (De)activate coverage analysis
run: echo "COVERAGE=${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 2 }}" >> "$GITHUB_ENV"
shell: bash
Expand All @@ -81,7 +91,7 @@ jobs:
# Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs
# Ref https://github.com/julia-actions/julia-runtest/pull/73
- name: Call Pkg.test
run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test-args }}
run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test.args }}
- uses: julia-actions/julia-processcoverage@v1
if: ${{ env.COVERAGE }}
- uses: codecov/codecov-action@v4
Expand Down
18 changes: 8 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.34.1"
version = "0.35.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -32,7 +32,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -49,22 +48,22 @@ TuringDynamicHMCExt = "DynamicHMC"
TuringOptimExt = "Optim"

[compat]
ADTypes = "0.2, 1"
ADTypes = "1.9"
AbstractMCMC = "5.2"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.6.0"
AdvancedVI = "0.2"
BangBang = "0.4"
BangBang = "0.4.2"
Bijectors = "0.13.6"
Compat = "4.15.0"
DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.28.2"
Compat = "4.15.0"
DynamicPPL = "0.29"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand All @@ -73,21 +72,20 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "5, 6"
NamedArrays = "0.9, 0.10"
Optim = "1"
Optimization = "3"
OptimizationOptimJL = "0.1, 0.2, 0.3"
OrderedCollections = "1"
Printf = "1"
Random = "1"
Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.92.1, 2"
SciMLBase = "2"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
Statistics = "1.6"
StatsAPI = "1.6"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
julia = "1.7"
julia = "1.10"

[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Expand Down
22 changes: 1 addition & 21 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ export @model, # modelling
AutoForwardDiff, # ADTypes
AutoReverseDiff,
AutoZygote,
AutoTracker,
AutoMooncake,
setprogress!, # debugging
Flat,
FlatPos,
Expand Down Expand Up @@ -136,24 +136,4 @@ export @model, # modelling
MAP,
MLE

# AutoTapir is only supported by ADTypes v1.0 and above.
@static if VERSION >= v"1.10" && pkgversion(ADTypes) >= v"1"
export AutoTapir
end

if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include(
"../ext/TuringOptimExt.jl"
)
@require DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" include(
"../ext/TuringDynamicHMCExt.jl"
)
end
end

end
11 changes: 3 additions & 8 deletions src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ using Bijectors: PDMatDistribution
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote
using ADTypes:
ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
Comment on lines +14 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
using ADTypes:
ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake


using AdvancedPS: AdvancedPS

Expand All @@ -20,16 +21,10 @@ include("container.jl")
export @model,
@varname,
AutoForwardDiff,
AutoTracker,
AutoZygote,
AutoReverseDiff,
AutoMooncake,
@logprob_str,
@prob_str

# AutoTapir is only supported by ADTypes v1.0 and above.
@static if VERSION >= v"1.10" && pkgversion(ADTypes) >= v"1"
using ADTypes: AutoTapir
export AutoTapir
end

end # module
1 change: 0 additions & 1 deletion src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using DynamicPPL:
getlogp,
VarName,
getsym,
vectorize,
_getvns,
getdist,
Model,
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
dist = getdist(varinfo, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("[ESS] only supports Gaussian prior distributions")
vectorize(dist, mean(dist))
DynamicPPL.tovec(mean(dist))

Check warning on line 88 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L88

Added line #L88 was not covered by tests
end
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
end
Expand Down
96 changes: 38 additions & 58 deletions src/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,43 +212,20 @@
Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
"""
function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple)
# TODO: Replace this with something like
# for vn in keys(vi)
# vi = DynamicPPL.setindex!!(vi, get(nt, vn))
# end
for (n, vals) in pairs(nt)
vns = vi.metadata[n].vns
nvns = length(vns)

# if there is a single variable only
if nvns == 1
# assign the unpacked values
if length(vals) == 1
vi[vns[1]] = [vals[1];]
# otherwise just assign the values
else
vi[vns[1]] = [vals;]
end
# if there are multiple variables
elseif vals isa AbstractArray
nvals = length(vals)
# if values are provided as an array with a single element
if nvals == 1
# iterate over variables and unpacked values
for (vn, val) in zip(vns, vals[1])
vi[vn] = [val;]
end
# otherwise number of variables and number of values have to be equal
elseif nvals == nvns
# iterate over variables and values
for (vn, val) in zip(vns, vals)
vi[vn] = [val;]
end
else
error("Cannot assign `NamedTuple` to `VarInfo`")
end
if vals isa AbstractVector
vals = unvectorize(vals)

Check warning on line 218 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L217-L218

Added lines #L217 - L218 were not covered by tests
end
if length(vns) == 1

Check warning on line 220 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L220

Added line #L220 was not covered by tests
# Only one variable, assign the values to it
DynamicPPL.setindex!(vi, vals, vns[1])

Check warning on line 222 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L222

Added line #L222 was not covered by tests
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
else
error("Cannot assign `NamedTuple` to `VarInfo`")
# Spread the values across the variables
length(vns) == length(vals) || error("Unequal number of variables and values")
for (vn, val) in zip(vns, vals)
DynamicPPL.setindex!(vi, val, vn)
end

Check warning on line 228 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L225-L228

Added lines #L225 - L228 were not covered by tests
end
end
end
Expand Down Expand Up @@ -285,10 +262,10 @@
unvectorize(dists::AbstractVector) = length(dists) == 1 ? first(dists) : dists

# possibly unpack and reshape samples according to the prior distribution
reconstruct(dist::Distribution, val::AbstractVector) = DynamicPPL.reconstruct(dist, val)
function reconstruct(dist::AbstractVector{<:UnivariateDistribution}, val::AbstractVector)
return val
function reconstruct(dist::Distribution, val::AbstractVector)
return DynamicPPL.from_vec_transform(dist)(val)

Check warning on line 266 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L265-L266

Added lines #L265 - L266 were not covered by tests
end
reconstruct(dist::AbstractVector{<:UnivariateDistribution}, val::AbstractVector) = val

Check warning on line 268 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L268

Added line #L268 was not covered by tests
function reconstruct(dist::AbstractVector{<:MultivariateDistribution}, val::AbstractVector)
offset = 0
return map(dist) do d
Expand Down Expand Up @@ -322,7 +299,7 @@
:(
$name = reconstruct(
unvectorize(DynamicPPL.getdist.(Ref(vi), vns.$name)),
DynamicPPL.getval(vi, vns.$name),
DynamicPPL.getindex_internal(vi, vns.$name),
)
) for name in names
]
Expand Down Expand Up @@ -465,42 +442,45 @@
####
#### Compiler interface, i.e. tilde operators.
####
function DynamicPPL.assume(rng, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi)
function DynamicPPL.assume(

Check warning on line 445 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L445

Added line #L445 was not covered by tests
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
)
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)

Check warning on line 449 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L449

Added line #L449 was not covered by tests
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!(vi, vn, spl)
r = vi[vn]
return r, logpdf_with_trans(dist, r, istrans(vi, vn)), vi
# Return.
return retval

Check warning on line 453 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L453

Added line #L453 was not covered by tests
end

function DynamicPPL.dot_assume(
rng,
spl::Sampler{<:MH},
dist::MultivariateDistribution,
vn::VarName,
vns::AbstractVector{<:VarName},
var::AbstractMatrix,
vi,
vi::AbstractVarInfo,
)
@assert dim(dist) == size(var, 1)
getvn = i -> VarName(vn, vn.indexing * "[:,$i]")
vns = getvn.(1:size(var, 2))
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
r = vi[vns]
var .= r
return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))), vi
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi)

Check warning on line 465 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L465

Added line #L465 was not covered by tests
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!.((vi,), vns, (spl,))

Check warning on line 467 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L467

Added line #L467 was not covered by tests
# Return.
return retval

Check warning on line 469 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L469

Added line #L469 was not covered by tests
end
function DynamicPPL.dot_assume(
rng,
spl::Sampler{<:MH},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vn::VarName,
vns::AbstractArray{<:VarName},
var::AbstractArray,
vi,
vi::AbstractVarInfo,
)
getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
vns = getvn.(CartesianIndices(var))
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
r = reshape(vi[vec(vns)], size(var))
var .= r
return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))), vi
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dists, vns, var, vi)

Check warning on line 480 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L480

Added line #L480 was not covered by tests
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
DynamicPPL.updategid!.((vi,), vns, (spl,))
return retval

Check warning on line 483 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L482-L483

Added lines #L482 - L483 were not covered by tests
end

function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@
elseif is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del") # Reference particle parent
r = rand(trng, dist)
vi[vn] = vectorize(dist, r)
vi[vn] = DynamicPPL.tovec(r)

Check warning on line 383 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L383

Added line #L383 was not covered by tests
DynamicPPL.setgid!(vi, spl.selector, vn)
setorder!(vi, vn, get_num_produce(vi))
else
Expand Down
Loading
Loading