Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch more fully to measures terminology #435

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Breaking changes
likelihood = MyLikeLihood(mydata)
```

This allows for defining likelihoods without depending on BAT. Avoid creating custom subtypes of `BAT.AbstractMeasureOrDensity`.
This allows for defining likelihoods without depending on BAT.

* New behavior of `ValueShapes.NamedTupleShape` and `ValueShapes.NamedTupleDist`: Due to changes in [ValueShapes](https://github.com/oschulz/ValueShapes.jl) v0.10, `NamedTupleShape` and `NamedTupleDist` now either (by default) use `NamedTuple` or (optionally) `ValueShapes.ShapedAsNT`, but no longer a mix of them. As a result, the behavior of BAT has changed as well when using a `NamedTupleDist` as a prior. For example, `mode(samples).result` returns a `NamedTuple` now directly.

Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "3.1.2"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
Expand All @@ -19,6 +20,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
EmpiricalDistributions = "0bbb1fad-0f24-45fe-94a4-415852c5cc3b"
Expand Down Expand Up @@ -85,6 +87,7 @@ BATUltraNestExt = "UltraNest"

[compat]
Accessors = "0.1"
Adapt = "3, 4"
AdvancedHMC = "0.5"
AffineMaps = "0.2.3, 0.3"
ArgCheck = "1, 2.0"
Expand All @@ -102,6 +105,7 @@ Distributed = "1"
Distributions = "0.25"
DistributionsAD = "0.5, 0.6"
DocStringExtensions = "0.8, 0.9"
DomainSets = "0.5, 0.6, 0.7"
DoubleFloats = "0.9, 1"
ElasticArrays = "1.2.3"
EmpiricalDistributions = "0.2, 0.3.1"
Expand Down Expand Up @@ -149,7 +153,7 @@ Tables = "0.2, 1.0"
Transducers = "0.4"
TypedTables = "1.2"
UltraNest = "0.1"
ValueShapes = "0.10.1"
ValueShapes = "0.11"
ZygoteRules = "0.2"
julia = "1.6"

Expand Down
6 changes: 3 additions & 3 deletions docs/src/experimental_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@ bat_compare
bat_integrated_autocorr_len
bat_marginalmode
BAT.auto_renormalize
BAT.batsampleable
BAT.BinnedModeEstimator
BAT.DistributionTransform
BAT.enable_error_log
BAT.error_log
BAT.EvalException
BAT.ext_default
BAT.get_adselector
BAT.LogUniform
BAT.PackageExtension
BAT.pkgext
BAT.set_rng
batmeasure
BridgeSampling
EllipsoidalNestedSampling
GridSampler
HierarchicalDistribution
PriorImportanceSampler
ReactiveNestedSampling
renormalize_density
SobolSampler
truncate_density
truncate_batmeasure
ValueAndThreshold
```
32 changes: 8 additions & 24 deletions docs/src/internal_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ Order = [:macro, :function]
# Documentation

```@docs
BAT.AbstractMeasureOrDensity
BAT.AbstractProposalDist
BAT.AbstractSampleGenerator
BAT.AnyIIDSampleable
BAT.AnyMeasureOrDensity
BAT.AnySampleable
BAT.BasicMvStatistics
BAT.BATMeasure
BAT.BATPushFwdMeasure
BAT.BATPwrMeasure
BAT.BATWeightedMeasure
BAT.CholeskyPartialWhitening
BAT.CholeskyWhitening
BAT.DistLikeMeasure
BAT.DensitySampleMeasure
BAT.ENSAutoProposal
BAT.ENSBound
BAT.ENSEllipsoidBound
Expand All @@ -52,63 +52,47 @@ BAT.ENSRandomWalk
BAT.ENSSlice
BAT.ENSUniformly
BAT.FullMeasureTransform
BAT.GenericDensity
BAT.LFDensity
BAT.LFDensityWithGrad
BAT.LogDVal
BAT.MCMCIterator
BAT.MCMCSampleGenerator
BAT.MeasureLike
BAT.NoWhitening
BAT.OnlineMvCov
BAT.OnlineMvMean
BAT.OnlineUvMean
BAT.OnlineUvVar
BAT.Renormalized
BAT.SampleTransformation
BAT.StandardMvNormal
BAT.StandardMvUniform
BAT.StandardUvNormal
BAT.StandardUvUniform
BAT.StatisticalWhitening
BAT.Transformed
BAT.UnshapeTransformation
BAT.WhiteningAlgorithm


BAT.WrappedNonBATDensity

BAT.trafoof
BAT.logvalof
BAT.bat_report!
BAT.fft_autocor
BAT.fft_autocov
BAT.argchoice_msg
BAT.bat_sampler
BAT.bg_R_2sqr
BAT.checked_logdensityof
BAT.default_val_numtype
BAT.default_var_numtype
BAT.density_valtype
BAT.drop_low_weight_samples
BAT.find_marginalmodes
BAT.fromuhc
BAT.fromuhc!
BAT.fromui
BAT.get_bin_centers
BAT.getlikelihood
BAT.getprior
BAT.gr_Rsqr
BAT.is_log_zero
BAT.issymmetric_around_origin
BAT.log_volume
BAT.log_zero_density
BAT.proposal_rand!
BAT.proposaldist_logpdf
BAT.repetition_to_weights
BAT.smallest_credible_intervals
BAT.spatialvolume
BAT.sum_first_dim
BAT.supports_rand
BAT.trunc_logpdf_ratio
BAT.truncate_dist_hard
BAT.var_bounds
BAT.measure_support
```
25 changes: 8 additions & 17 deletions docs/src/tutorial_lit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,25 +127,20 @@ using BAT, DensityInterface, IntervalSets

# ### Likelihood Definition
#
# First, we need to define the likelihood (function) for our problem.
# First, we need to define the likelihood for our problem.
#
# BAT represents densities like likelihoods and priors as subtypes of
# `BAT.AbstractMeasureOrDensity`. Custom likelihood can be defined by
# creating a new subtype of `AbstractMeasureOrDensity` and by implementing (at minimum)
# `DensityInterface.logdensityof` for that type - in complex uses cases, this may
# become necessary. Typically, however, it is sufficient to define a custom
# likelihood as a simple function that returns the log-likelihood value for
# a given set of parameters. BAT will automatically convert such a
# likelihood function into a subtype of `AbstractMeasureOrDensity`.
# BAT expects likelihoods to implements the `DensityInterface` API. We
# can simply wrap a log-likelihood function with
# `DensityInterface.logfuncdensity` to make it compatible.
#
# For performance reasons, functions should [not access global variables
# directly] (https://docs.julialang.org/en/v1/manual/performance-tips/index.html#Avoid-global-variables-1).
# So we'll use an [anonymous function](https://docs.julialang.org/en/v1/manual/functions/#man-anonymous-functions-1)
# inside of a [let-statement](https://docs.julialang.org/en/v1/base/base/#let)
# to capture the value of the global variable `hist` in a local variable `h`
# (and to shorten function name `fit_function` to `f`, purely for
# convenience). `DensityInterface.logfuncdensity` turns a log-likelihood
# function into a density object.
# convenience). `DensityInterface.logfuncdensity` then turns the
# log-likelihood function into a `DensityInterface` density object.

likelihood = let h = hist, f = fit_function
## Histogram counts for each bin as an array:
Expand Down Expand Up @@ -208,12 +203,8 @@ prior = distprod(

#md nothing # hide

# In general, BAT allows instances of any subtype of `AbstractMeasureOrDensity` to
# be uses as a prior, as long as a sampler is defined for it. This way, users
# may implement complex application-specific priors. You can also
# use `convert(AbstractMeasureOrDensity, distribution)` to convert any
# continuous multivariate `Distributions.Distribution` to a
# `BAT.AbstractMeasureOrDensity` that can be used as a prior (or likelihood).
# BAT supports most `Distributions.Distribution` types, and combinations
# of them, as priors.


# ### Bayesian Model Definition
Expand Down
48 changes: 0 additions & 48 deletions examples/dev-internal/ahmi_example.jl

This file was deleted.

8 changes: 5 additions & 3 deletions ext/BATAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,26 @@ using Random
using DensityInterface
using HeterogeneousComputing, AutoDiffOperators

using BAT: AnyMeasureOrDensity, AbstractMeasureOrDensity
using BAT: MeasureLike, BATMeasure

using BAT: get_context, get_adselector, _NoADSelected
using BAT: getalgorithm, getmeasure
using BAT: getalgorithm, mcmc_target
using BAT: MCMCIterator, MCMCIteratorInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, AbstractMCMCTunerInstance
using BAT: AbstractTransformTarget
using BAT: RNGPartition, set_rng!
using BAT: mcmc_step!, nsamples, nsteps, samples_available, eff_acceptance_ratio
using BAT: get_samples!, get_mcmc_tuning, reset_rng_counters!
using BAT: tuning_init!, tuning_postinit!, tuning_reinit!, tuning_update!, tuning_finalize!, tuning_callback
using BAT: totalndof, var_bounds, checked_logdensityof
using BAT: totalndof, measure_support, checked_logdensityof
using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJECTED_SAMPLE

using BAT: HamiltonianMC
using BAT: AHMCSampleID, AHMCSampleIDVector
using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric
using BAT: HMCTuningAlgorithm, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning

using ValueShapes: varshape

using Accessors: @set


Expand Down
23 changes: 11 additions & 12 deletions ext/BATCubaExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ end
using BAT
BAT.pkgext(::Val{:Cuba}) = BAT.PackageExtension{:Cuba}()

using BAT: AnyMeasureOrDensity, AbstractMeasureOrDensity
using BAT: MeasureLike, BATMeasure
using BAT: CubaIntegration
using BAT: var_bounds, bat_integrate_impl
using BAT: measure_support, bat_integrate_impl
using BAT: transform_and_unshape, auto_renormalize

using Base.Threads: @threads
Expand Down Expand Up @@ -125,17 +125,16 @@ function _integrate_impl_cuba(integrand::CubaIntegrand, algorithm::CuhreIntegrat
end


function BAT.bat_integrate_impl(target::AnyMeasureOrDensity, algorithm::CubaIntegration, context::BATContext)
measure = convert(AbstractMeasureOrDensity, target)
transformed_measure, _ = transform_and_unshape(algorithm.trafo, target, context)
function BAT.bat_integrate_impl(target::MeasureLike, algorithm::CubaIntegration, context::BATContext)
measure = batmeasure(target)
transformed_measure, _ = transform_and_unshape(algorithm.trafo, measure, context)

vb = var_bounds(transformed_measure)
if !(all(isapprox(0), vb.vol.lo) && all(isapprox(1), vb.vol.hi))
throw(ArgumentError("CUBA integration requires measures that (can be converted to) have unit volume support"))
if !BAT.has_uhc_support(transformed_measure)
throw(ArgumentError("CUBA integration requires measures are supported only on the unit hypercube"))
end

renormalized_measure, logrenormf = auto_renormalize(transformed_measure)
dof = totalndof(renormalized_measure)
renormalized_measure, logweight = auto_renormalize(transformed_measure)
dof = totalndof(varshape(renormalized_measure))
integrand = CubaIntegrand(logdensityof(renormalized_measure), dof)

r_cuba = _integrate_impl_cuba(integrand, algorithm, context)
Expand All @@ -152,9 +151,9 @@ function BAT.bat_integrate_impl(target::AnyMeasureOrDensity, algorithm::CubaInte
end

(value, error) = first(r_cuba.integral), first(r_cuba.error)
rescaled_value, rescaled_error = exp(BigFloat(log(value) - logrenormf)), exp(BigFloat(log(error) - logrenormf))
rescaled_value, rescaled_error = exp(BigFloat(log(value) - logweight)), exp(BigFloat(log(error) - logweight))
result = Measurements.measurement(rescaled_value, rescaled_error)
return (result = result, logrenormf = logrenormf, cuba_result = r_cuba)
return (result = result, logweight = logweight, cuba_result = r_cuba)
end


Expand Down
26 changes: 14 additions & 12 deletions ext/BATNestedSamplersExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using HeterogeneousComputing

BAT.pkgext(::Val{:NestedSamplers}) = BAT.PackageExtension{:NestedSamplers}()

using BAT: AnyMeasureOrDensity, AbstractMeasureOrDensity
using BAT: MeasureLike, BATMeasure
using BAT: ENSBound, ENSNoBounds, ENSEllipsoidBound, ENSMultiEllipsoidBound
using BAT: ENSProposal, ENSUniformly, ENSAutoProposal, ENSRandomWalk, ENSSlice

Expand Down Expand Up @@ -63,16 +63,18 @@ end



function BAT.bat_sample_impl(target::AnyMeasureOrDensity, algorithm::EllipsoidalNestedSampling, context::BATContext)
function BAT.bat_sample_impl(m::BATMeasure, algorithm::EllipsoidalNestedSampling, context::BATContext)
# ToDo: Forward RNG from context!
rng = get_rng(context)

density_notrafo = convert(AbstractMeasureOrDensity, target)
density, trafo = BAT.transform_and_unshape(algorithm.trafo, density_notrafo, context) # BAT prior transformation
vs = varshape(density)
dims = totalndof(vs)
transformed_m, trafo = BAT.transform_and_unshape(algorithm.trafo, m, context) # BAT prior transformation
dims = totalndof(varshape(transformed_m))

model = NestedModel(logdensityof(density), identity); # identity, because ahead the BAT prior transformation is used instead
if !BAT.has_uhc_support(transformed_m)
throw(ArgumentError("$algorithm doesn't measures that are not limited to the unit hypercube"))
end

model = NestedModel(logdensityof(transformed_m), identity); # identity, because ahead the BAT prior transformation is used instead
bounding = ENSBounding(algorithm.bound)
prop = ENSprop(algorithm.proposal)
sampler = Nested(
Expand All @@ -89,15 +91,15 @@ function BAT.bat_sample_impl(target::AnyMeasureOrDensity, algorithm::Ellipsoidal
weights = samples_w[:, end] # the last elements of the vectors are the weights
nsamples = size(samples_w,1)
samples = [samples_w[i, 1:end-1] for i in 1:nsamples] # the other ones (between 1 and end-1) are the samples
logvals = map(logdensityof(density), samples) # posterior values of the samples
samples_trafo = vs.(BAT.DensitySampleVector(samples, logvals, weight = weights))
samples_notrafo = inverse(trafo).(samples_trafo) # Here the samples are retransformed
logvals = map(logdensityof(transformed_m), samples) # posterior values of the samples
transformed_smpls = BAT.DensitySampleVector(samples, logvals, weight = weights)
smpls = inverse(trafo).(transformed_smpls) # Here the samples are retransformed

logintegral = Measurements.measurement(state.logz, state.logzerr)
ess = bat_eff_sample_size(samples_notrafo, KishESS(), context).result
ess = bat_eff_sample_size(smpls, KishESS(), context).result

return (
result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo,
result = smpls, result_trafo = transformed_smpls, trafo = trafo,
logintegral = logintegral, ess = ess,
info = state
)
Expand Down
Loading
Loading