Add MGVI and rename several things in API #457

Merged (13 commits) on Oct 23, 2024
9 changes: 8 additions & 1 deletion NEWS.md
@@ -14,6 +14,8 @@ Several algorithms have changed their names, but also their role:
changed (no deprecation for the parameter changes). Tuning and
sample weighting scheme selection have moved to `TransformedMCMC`.

* `PriorToGaussian` has become `PriorToNormal`.

Partial deprecations are available for the above, so a lot of old code should
run more or less unchanged (with deprecation warnings). Also:

@@ -27,6 +29,9 @@ run more or less unchanged (with deprecation warnings). Also:

* `MCMCTuningAlgorithm` has been replaced by `MCMCTransformTuning`.

* The `trafo` parameter of algorithms has been renamed to `pretransform`, and the
`trafo` field in algorithm results has been renamed to `f_pretransform`.
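
A keyword rename such as `trafo` to `pretransform` can stay backward compatible through a small shim. The sketch below shows that generic pattern in plain Julia. It is not BAT's actual deprecation code, and `DemoAlgorithm` is a hypothetical stand-in type introduced only for illustration.

```julia
# Hypothetical stand-in for an algorithm type; not BAT's actual definition.
struct DemoAlgorithm{T}
    pretransform::T
end

# Keyword constructor that accepts the new name and, for now, the old one.
function DemoAlgorithm(; pretransform = nothing, trafo = nothing)
    if trafo !== nothing
        pretransform === nothing ||
            throw(ArgumentError("pass either `pretransform` or `trafo`, not both"))
        # Emits a warning only when Julia runs with `--depwarn=yes`.
        Base.depwarn("keyword `trafo` is deprecated, use `pretransform` instead", :DemoAlgorithm)
        pretransform = trafo
    end
    return DemoAlgorithm(pretransform)
end
```

With this pattern, `DemoAlgorithm(trafo = identity)` and `DemoAlgorithm(pretransform = identity)` construct the same object, so old call sites keep working while they are migrated.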


### New features

@@ -39,11 +44,13 @@ run more or less unchanged (with deprecation warnings). Also:
via tunable space transformations instead of tuning covariance matrices
in proposal distributions.

MCMC tuning has been split into proposal tuning (algorithms of type
* MCMC tuning has been split into proposal tuning (algorithms of type
`MCMCProposalTuning`) and transform tuning (algorithms of type
`MCMCTransformTuning`). Proposal tuning now has a much more limited role
and often may be `NoMCMCProposalTuning()` (e.g. for `RandomWalk`).

* Added `MGVISampling` for Metric Gaussian Variational Inference.
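
The idea of tuning a space transformation instead of a proposal covariance can be illustrated with a fixed affine map: an untuned unit-scale step in the transformed space is the same move as a step scaled by `L` in the original space. This is a minimal sketch in plain Julia, assuming an affine transform `x = L * z + μ`. All names are hypothetical and this is not BAT's implementation.

```julia
using LinearAlgebra

# Hypothetical affine space transformation x = L * z + μ and its inverse.
to_target(L, μ, z) = L * z + μ
to_latent(L, μ, x) = L \ (x - μ)

L = [2.0 0.0; 1.0 0.5]   # lower-triangular factor, proposal covariance Σ = L * L'
μ = [1.0, -1.0]
x = [0.3, 0.7]           # current position in the original (target) space
ε = [0.1, -0.2]          # one fixed draw from a unit-covariance proposal

# (a) Step with a "tuned" proposal directly in the original space.
x_new_a = x + L * ε

# (b) Untuned unit step in the transformed space, mapped back afterwards.
z = to_latent(L, μ, x)
x_new_b = to_target(L, μ, z + ε)
```

Both routes reach the same point, so adapting `L` and `μ` in the transform plays the role that adapting the proposal covariance would otherwise play.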


BAT.jl v3.0.0
-------------
10 changes: 7 additions & 3 deletions Project.toml
@@ -69,33 +69,36 @@ AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
MGVI = "fdae7790-d271-4276-880d-f72bbddf129c"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"


[extensions]
BATAdvancedHMCExt = "AdvancedHMC"
BATCubaExt = "Cuba"
BATFoldsExt = ["Folds", "Transducers"]
BATHDF5Ext = "HDF5"
BATMGVIExt = "MGVI"
BATNestedSamplersExt = "NestedSamplers"
BATOptimExt = "Optim"
BATOptimizationExt = ["Optimization", "ADTypes"]
BATPlotsExt = "Plots"
BATUltraNestExt = "UltraNest"

[compat]
ADTypes = "0.1, 0.2, 1"
Accessors = "0.1"
Adapt = "3, 4"
ADTypes = "0.1, 0.2, 1"
AdvancedHMC = "0.5, 0.6"
AffineMaps = "0.2.3, 0.3"
ArgCheck = "1, 2.0"
ArraysOfArrays = "0.4, 0.5, 0.6"
AutoDiffOperators = "0.1.8, 0.2"
AutoDiffOperators = "0.2.1"
ChainRulesCore = "0.9.44, 0.10, 1"
ChangesOfVariables = "0.1.1"
Clustering = "0.13, 0.14, 0.15"
@@ -132,13 +135,14 @@ MacroTools = "0.5"
Markdown = "1"
MeasureBase = "0.12, 0.13, 0.14"
Measurements = "2"
MGVI = "0.4"
NamedArrays = "0.9, 0.10"
NestedSamplers = "0.8"
Optim = "0.19,0.20, 0.21, 0.22, 1"
Optimization = "3, 4"
PDMats = "0.9, 0.10, 0.11"
ParallelProcessingTools = "0.4"
Parameters = "0.12, 0.13"
PDMats = "0.9, 0.10, 0.11"
Plots = "1"
PositiveFactorizations = "0.2"
Printf = "1"
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -9,6 +9,7 @@ IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MGVI = "fdae7790-d271-4276-880d-f72bbddf129c"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
4 changes: 2 additions & 2 deletions docs/src/list_of_algorithms.md
@@ -32,7 +32,7 @@ BAT sampling algorithm type: [`TransformedMCMC`](@ref), MCMC algorithm subtype:

```julia
import AdvancedHMC, ForwardDiff
set_batcontext(ad = ADSelector(ForwardDiff))
set_batcontext(ad = ForwardDiff)
bat_sample(target, TransformedMCMC(mcalg = HamiltonianMC()))
```
Requires the [AdvancedHMC](https://github.com/TuringLang/AdvancedHMC.jl) Julia package to be loaded explicitly.
@@ -160,7 +160,7 @@ using Optim
bat_findmode(target, OptimAlg(optalg = Optim.NelderMead()))

import ForwardDiff
set_batcontext(ad = ADSelector(ForwardDiff))
set_batcontext(ad = ForwardDiff)
bat_findmode(target, OptimAlg(optalg = Optim.LBFGS()))
```

6 changes: 5 additions & 1 deletion docs/src/stable_api.md
@@ -44,6 +44,7 @@ bat_transform

get_batcontext
set_batcontext
log_batdebug

distbind
distprod
@@ -69,6 +70,7 @@ EffSampleSizeAlgorithm
EffSampleSizeFromAC
EvaluatedMeasure
ExplicitInit
FixedMGVISchedule
FixedNBins
FreedmanDiaconisBinning
GelmanRubinConvergence
@@ -90,6 +92,7 @@ MCMCInitAlgorithm
MCMCMultiCycleBurnin
MCMCProposalTuning
MCMCTransformTuning
MGVISampling
ModeAsDefined
NoMCMCProposalTuning
NoMCMCTransformTuning
@@ -98,7 +101,7 @@ OptimizationAlg
OrderedResampling
PosteriorMeasure
PriorSubstitution
PriorToGaussian
PriorToNormal
PriorToUniform
RAMTuning
RandomWalk
@@ -120,4 +123,5 @@ BAT.AbstractMedianEstimator
BAT.AbstractModeEstimator
BAT.AbstractSamplingAlgorithm
BAT.ConvergenceTest
BAT.MGVISchedule
```
6 changes: 3 additions & 3 deletions docs/src/tutorial_lit.jl
@@ -222,10 +222,10 @@ posterior = PosteriorMeasure(likelihood, prior)
# possible parameter values for the histogram fit.
#
# To increase the verbosity level of BAT logging output, you may want to set
# the Julia logging level for BAT to debug via `ENV["JULIA_DEBUG"] = "BAT"`.
# the Julia logging level for BAT to debug via `bat_logdebug()`.

#nb ENV["JULIA_DEBUG"] = "BAT"
#jl ENV["JULIA_DEBUG"] = "BAT"
#nb bat_logdebug()
#jl bat_logdebug()
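
The removed lines above enabled debug output by overwriting `ENV["JULIA_DEBUG"]`, which also discards any other module names already listed there. As a rough sketch of a less destructive helper, assuming only Julia's standard `JULIA_DEBUG` mechanism: the function name `enable_debug_logging` is hypothetical, and this diff does not show whether BAT's own helper works this way.

```julia
# Hypothetical helper: add a module name to the comma-separated JULIA_DEBUG list
# without dropping names that are already present.
function enable_debug_logging(modname::AbstractString)
    current = get(ENV, "JULIA_DEBUG", "")
    names = filter(!isempty, String.(strip.(split(current, ','))))
    modname in names || push!(names, modname)
    ENV["JULIA_DEBUG"] = join(names, ",")
    return ENV["JULIA_DEBUG"]
end
```

Calling `enable_debug_logging("BAT")` twice leaves the variable unchanged after the first call.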

# Now we can generate a set of MCMC samples via [`bat_sample`](@ref). We'll
# use 4 MCMC chains with 10^5 MC steps in each chain (after tuning/burn-in):
2 changes: 1 addition & 1 deletion examples/benchmarks/benchmarks.jl
@@ -2,7 +2,7 @@ using BAT, ValueShapes, IntervalSets, Distributions, Plots, EmpiricalDistributio
using BATTestCases
using AHMI
using StatsBase, ArraysOfArrays, LinearAlgebra, LaTeXStrings, QuadGK, PrettyTables, HypothesisTests, Statistics
ENV["JULIA_DEBUG"] = "BAT"
bat_logdebug()

function setup_benchmark()
if(!(("plots1D" in readdir()) && ("plots2D" in readdir()) && ("results" in readdir())))
8 changes: 4 additions & 4 deletions examples/benchmarks/run_benchmark_ND.jl
@@ -249,7 +249,7 @@ function run_ND_benchmark(;
dis,
TransformedMCMC(
mcalg = algorithm,
trafo = DoNotTransform(),
pretransform = DoNotTransform(),
nchains = nchains,
nsteps = nsteps,
init = init,
@@ -262,7 +262,7 @@ function run_ND_benchmark(;
elseif isa(algorithm,BAT.HamiltonianMC)
mcmc_sample = bat_sample(
dis,
TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)
TransformedMCMC(mcalg = algorithm, pretransform = DoNotTransform(), nchains = nchains, nsteps = nsteps)
).result
end
taf = time()
@@ -276,7 +276,7 @@ function run_ND_benchmark(;
dis,
TransformedMCMC(
mcalg = algorithm,
trafo = DoNotTransform(),
pretransform = DoNotTransform(),
nchains = nchains,
nsteps = nsteps,
init = init,
@@ -289,7 +289,7 @@ function run_ND_benchmark(;
elseif isa(algorithm,BAT.HamiltonianMC)
bat_sample(
dis,
TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)
TransformedMCMC(mcalg = algorithm, pretransform = DoNotTransform(), nchains = nchains, nsteps = nsteps)
).result
end
taf = time()
8 changes: 4 additions & 4 deletions examples/benchmarks/utils.jl
@@ -207,10 +207,10 @@ function run1D(
)

sample_stats_all = []
samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, pretransform = DoNotTransform(), nchains = nchains, nsteps = nsteps))
for i in 1:n_runs
time_before = time()
samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, pretransform = DoNotTransform(), nchains = nchains, nsteps = nsteps))
time_after = time()

h = plot1D(samples,testfunctions,key,sample_stats)# posterior, key, analytical_stats,sample_stats)
@@ -438,10 +438,10 @@ function run2D(

sample_stats_all = []

samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, pretransform = DoNotTransform(), nchains = nchains, nsteps = nsteps))
for i in 1:n_runs
time_before = time()
samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, pretransform = DoNotTransform(), nchains = nchains, nsteps = nsteps))
time_after = time()

h = plot2D(samples, testfunctions, key, sample_stats)
8 changes: 3 additions & 5 deletions examples/dev-internal/test_findmode.jl
@@ -59,9 +59,9 @@ my_function(; nt...)

context = get_batcontext()
target = posterior
transformed_density, trafo = BAT.transform_and_unshape(PriorToGaussian(), target, context)
inv_trafo = inverse(trafo)
initalg = BAT.apply_trafo_to_init(trafo, InitFromTarget())
transformed_density, f_transform = BAT.transform_and_unshape(PriorToNormal(), target, context)
inv_trafo = inverse(f_transform)
initalg = BAT.apply_trafo_to_init(f_transform, InitFromTarget())
x_init = collect(bat_initval(transformed_density, initalg, context).result)

f = fchain(inv_trafo, logdensityof(target), -)
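
The `fchain(inv_trafo, logdensityof(target), -)` line builds an objective that maps a point back to the original space, evaluates the log-density there, and negates it. The same left-to-right chain can be sketched with Base's `∘` operator, which composes right to left. The toy density and transform below are hypothetical stand-ins, not the objects from this script.

```julia
# Toy stand-ins (hypothetical): a log-density on x > 0 and a map from
# unconstrained z to x = exp(z).
logd(x) = -x                  # unnormalized log-density of an Exponential(1) variable
inv_transform(z) = exp(z)     # maps the transformed space back to the original space

# "inv_transform, then logd, then negate", written right-to-left with ∘.
neg_logd = (-) ∘ logd ∘ inv_transform

neg_logd(0.0)   # evaluates -logd(exp(0.0))
```

A minimizer applied to `neg_logd` then searches in the transformed space, and the result is mapped back with `inv_transform`.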
@@ -83,8 +83,6 @@ rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
f = rosenbrock


using AutoDiffOperators

b = Optimization.SciMLBase.NoAD()
supertype(typeof(b))

3 changes: 1 addition & 2 deletions examples/paper-example/paper_example.jl
@@ -9,10 +9,9 @@ using ArraysOfArrays
using TypedTables
using CSV
import Cuba, AdvancedHMC, ForwardDiff
using AutoDiffOperators
#using AHMI

BAT.set_batcontext(ad = ADSelector(ForwardDiff))
BAT.set_batcontext(ad = ForwardDiff)


function log_pdf_poisson(λ::T, k::U) where {T<:Real,U<:Real}
2 changes: 1 addition & 1 deletion ext/BATCubaExt.jl
@@ -123,7 +123,7 @@ end

function BAT.bat_integrate_impl(target::MeasureLike, algorithm::CubaIntegration, context::BATContext)
measure = batmeasure(target)
transformed_measure, _ = transform_and_unshape(algorithm.trafo, measure, context)
transformed_measure, _ = transform_and_unshape(algorithm.pretransform, measure, context)

if !BAT.has_uhc_support(transformed_measure)
throw(ArgumentError("CUBA integration requires measures that are supported only on the unit hypercube"))
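
The unit-hypercube requirement in the Cuba extension is why a pretransform is applied before integrating. As a one-dimensional illustration of the underlying change of variables, here is a minimal sketch in plain Julia with a midpoint rule standing in for Cuba. The function name `integrate_on_unit_interval` is hypothetical and none of this is BAT or Cuba code.

```julia
# Integrate f over [a, b] by substituting x = a + (b - a) * u with u in [0, 1].
# The Jacobian of this map is the constant (b - a).
function integrate_on_unit_interval(f, a, b; n = 10_000)
    h = 1 / n
    total = 0.0
    for i in 1:n
        u = (i - 0.5) * h            # midpoint of the i-th cell in [0, 1]
        x = a + (b - a) * u          # mapped point in [a, b]
        total += f(x) * (b - a) * h  # integrand times Jacobian times cell width
    end
    return total
end

integrate_on_unit_interval(x -> x^2, 1.0, 3.0)   # exact value is 26/3
```

The integrator only ever sees points in the unit interval; the mapping and its Jacobian carry the information about the original domain.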