Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first draft of implementing the bones/structure of adaptive_hmc #766

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
269 changes: 213 additions & 56 deletions R/inference_class.R

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions R/samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,47 @@ slice <- function(max_doublings = 5) {
obj
}

#' @rdname samplers
#' @export
#'
#' @param epsilon leapfrog stepsize hyperparameter (positive, will be tuned)
#' @param diag_sd estimate of the posterior marginal standard deviations
#' (positive, will be tuned).
#' @param max_leapfrog_steps numeric. Default 1000. Maximum number of leapfrog
#' steps used. The algorithm will determine the optimal number less than this.
#' @param method character length one. Currently can only be "SNAPER" but in
#' the future this may expand to other adaptive samplers.
#' @details For `adaptive_hmc()`. The Lmin and Lmax parameters are learnt, and
#'   so are not provided here. The number of chains cannot be less than 2, due to
#' how adaptive HMC works. `diag_sd` is used to rescale the parameter space to
#' make it more uniform, and make sampling more efficient.
adaptive_hmc <- function(
  max_leapfrog_steps = 1000,
  epsilon = 0.1,
  diag_sd = 1,
  method = "SNAPER"
) {

  # only SNAPER adaptation is currently implemented; arg_match() errors
  # informatively on anything else
  method <- rlang::arg_match(
    arg = method,
    values = "SNAPER"
  )

  # removed an orphaned `# nolint end` here: there was no matching
  # `# nolint start`, likely a leftover from copying hmc()'s definition

  # bundle the tuning parameters together with the sampler class
  # constructor (adaptive_hmc_sampler, defined in R/inference_class.R)
  obj <- list(
    parameters = list(
      max_leapfrog_steps = max_leapfrog_steps,
      epsilon = epsilon,
      diag_sd = diag_sd
    ),
    class = adaptive_hmc_sampler
  )
  class(obj) <- c("adaptive_hmc_sampler", "sampler")
  obj
}



#' @noRd
#' @export
print.sampler <- function(x, ...) {
Expand Down
84 changes: 84 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1003,3 +1003,87 @@ n_warmup <- function(x){
x_info <- attr(x, "model_info")
x_info$warmup
}

# describe a tensor as a TensorSpec carrying its shape and dtype
build_tensor_spec <- function(tensor) {
  tf$TensorSpec(
    shape = tensor$shape,
    dtype = tensor$dtype
  )
}

# convert tensors to their TensorSpec; pass non-tensors through untouched
maybe_make_tensor_shape <- function(x) {
  if (!tf$is_tensor(x)) {
    return(x)
  }
  build_tensor_spec(x)
}

# pull out the final model parameter state of a chain from the all_states
# object returned by tfp$mcmc$sample_chain (iterations are on axis 0)
get_last_state <- function(all_states) {
  last_index <- dim(all_states)[1] - 1L
  tf$gather(all_states, last_index, 0L)
}

# flag MCMC iterations whose acceptance probability was non-finite
# (returns a logical array, TRUE where the step was "bad")
bad_steps <- function(kernel_results) {
  ratios <- recursive_get_log_accept_ratio(kernel_results)
  !is.finite(ratios)
}


# recursively extract the log acceptance ratio from the MCMC kernel
recursive_get_log_accept_ratio <- function(kernel_results) {
  # walk down through nested kernel results until a log_accept_ratio
  # field is found, erroring if the structure is not one we recognise
  fields <- names(kernel_results)
  ratios <- if ("log_accept_ratio" %in% fields) {
    kernel_results$log_accept_ratio
  } else if ("inner_results" %in% fields) {
    recursive_get_log_accept_ratio(kernel_results$inner_results)
  } else {
    stop("non-standard kernel structure")
  }
  as.array(ratios)
}


# given a warmed up sampler object, return a compiled TF function that
# generates a new burst of samples from it
# Build and compile a sampling function from a warmed-up sampler object.
#
# warm_sampler is expected to carry (TODO confirm against inference_class.R):
#   $kernel          - the tuned tfp MCMC transition kernel
#   $kernel_results  - kernel results after warmup (tuned settings)
#   $current_state   - the parameter state to resume sampling from
make_sampler_function <- function(warm_sampler) {

  # make the uncompiled function (with curried arguments)
  sample_raw <- function(current_state, n_samples) {
    results <- tfp$mcmc$sample_chain(
      # how many iterations
      num_results = n_samples,
      # where to start from
      current_state = current_state,
      # kernel
      kernel = warm_sampler$kernel,
      # tuned sampler settings
      previous_kernel_results = warm_sampler$kernel_results,
      # trace the full kernel results so badness can be diagnosed later
      trace_fn = function(current_state, kernel_results) {
        # could compute badness here to save memory?
        # is.finite(kernel_results$inner_results$inner_results$inner_results$log_accept_ratio)
        kernel_results
      }
    )
    # return the parameter states and the kernel results
    list(
      all_states = results$all_states,
      kernel_results = results$trace
    )
  }

  # compile it into a concrete function and return
  # NOTE(review): `as_tensorspec()` is not defined in this file — presumably
  # it lives elsewhere in the package (cf. maybe_make_tensor_shape()/
  # build_tensor_spec() above); confirm it exists or this fails at call time
  sample <- tf_function(sample_raw,
                        list(
                          as_tensorspec(warm_sampler$current_state),
                          tf$TensorSpec(shape = c(),
                                        dtype = tf$int32)
                        ))

  sample

}
2 changes: 2 additions & 0 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ p_theta_greta <- function(
data,
p_theta,
p_x_bar_theta,
# TODO note that we might want to change this to adaptive_hmc()
sampler = hmc(),
warmup = 1000
) {
Expand Down Expand Up @@ -926,6 +927,7 @@ get_distribution_name <- function(x){
check_samples <- function(
x,
iid_function,
# TODO note that we might want to change this to adaptive_hmc
sampler = hmc(),
n_effective = 3000,
title = NULL,
Expand Down
61 changes: 61 additions & 0 deletions tests/testthat/test-adaptive-hmc.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# NOTE(review): this is arithmetic, not a date — it seeds with
# 2025 - 2 - 13 = 2010; confirm that value is intended
set.seed(2025 - 02 - 13)

test_that("bad mcmc proposals are rejected", {
  skip_if_not(check_tf_version())

  # set up for numerical rejection of initial location: data far in the
  # positive tail, prior far in the negative tail, with tiny standard
  # deviations, so initial proposals are numerically bad
  x <- rnorm(10000, 1e60, 1)
  z <- normal(-1e60, 1e-60)
  distribution(x) <- normal(z, 1e-60)
  m <- model(z, precision = "single")

  # catch badness in the progress bar (default sampler)
  out <- get_output(
    mcmc(m, n_samples = 10, warmup = 0, pb_update = 10)
  )
  expect_match(out, "100% bad")

  # adaptive_hmc() should error informatively from extreme initial values
  expect_snapshot(
    error = TRUE,
    draws <- mcmc(
      m,
      chains = 1,
      n_samples = 2,
      warmup = 0,
      verbose = FALSE,
      sampler = adaptive_hmc(),
      initial_values = initials(z = 1e120)
    )
  )

  # really bad proposals
  x <- rnorm(100000, 1e120, 1)
  z <- normal(-1e120, 1e-120)
  distribution(x) <- normal(z, 1e-120)
  m <- model(z, precision = "single")
  expect_snapshot(
    error = TRUE,
    mcmc(m,
         chains = 1,
         n_samples = 1,
         warmup = 0,
         sampler = adaptive_hmc(),
         verbose = FALSE)
  )

  # proposals that are fine, but rejected anyway (huge stepsize)
  z <- normal(0, 1)
  m <- model(z, precision = "single")
  expect_ok(mcmc(
    m,
    adaptive_hmc(
      epsilon = 100,
      # Lmin/Lmax are not arguments of adaptive_hmc (L is adapted);
      # kept here commented from the hmc() version of this test
      # Lmin = 1,
      # Lmax = 1
    ),
    chains = 1,
    n_samples = 5,
    warmup = 0,
    verbose = FALSE
  ))
})
3 changes: 3 additions & 0 deletions tests/testthat/test_inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ test_that("samplers print informatively", {
expect_snapshot(
hmc(Lmin = 1)
)
expect_snapshot(
adaptive_hmc(Lmin = 1)
)

# # check print sees changed parameters
# out <- capture_output(hmc(Lmin = 1), TRUE)
Expand Down
34 changes: 33 additions & 1 deletion tests/testthat/test_posteriors_binomial.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
test_that("posterior is correct (binomial)", {
test_that("posterior is correct (binomial) with hmc", {
skip_if_not(check_tf_version())
skip_on_os("windows")
# analytic solution to the posterior of the parameter of a binomial
Expand Down Expand Up @@ -29,3 +29,35 @@ test_that("posterior is correct (binomial)", {
suppressWarnings(test <- ks.test(samples, comparison))
expect_gte(test$p.value, 0.01)
})

test_that("posterior is correct (binomial) with adaptive hmc", {
  skip_if_not(check_tf_version())
  skip_on_os("windows")

  # analytic solution to the posterior of the parameter of a binomial
  # distribution, with uniform prior
  n <- 100
  pos <- rbinom(1, n, runif(1))
  theta <- uniform(0, 1)
  distribution(pos) <- binomial(n, theta)
  m <- model(theta)

  # sample with the adaptive HMC sampler until there are enough
  # effective draws for the comparison below
  draws <- get_enough_draws(m, adaptive_hmc(), 2000, verbose = FALSE)

  samples <- as.matrix(draws)

  # analytic solution to posterior is beta(1 + pos, 1 + N - pos)
  shape1 <- 1 + pos
  shape2 <- 1 + n - pos

  # qq plot against true quantiles
  quants <- (1:99) / 100
  q_target <- qbeta(quants, shape1, shape2)
  q_est <- quantile(samples, quants)
  plot(q_target ~ q_est, main = "binomial posterior")
  abline(0, 1)

  # formal check: KS test of the mcmc draws against iid draws from the
  # analytic posterior, sized by the effective sample size
  n_draws <- round(coda::effectiveSize(draws))
  comparison <- rbeta(n_draws, shape1, shape2)
  suppressWarnings(test <- ks.test(samples, comparison))
  expect_gte(test$p.value, 0.01)
})
4 changes: 4 additions & 0 deletions tests/testthat/test_posteriors_bivariate_normal.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ test_that("samplers are unbiased for bivariate normals", {
hmc_mvn_samples <- check_mvn_samples(sampler = hmc())
expect_lte(max(hmc_mvn_samples), stats::qnorm(0.99))

adaptive_hmc_mvn_samples <- check_mvn_samples(sampler = adaptive_hmc())
expect_lte(max(adaptive_hmc_mvn_samples), stats::qnorm(0.99))

rwmh_mvn_samples <- check_mvn_samples(sampler = rwmh())
expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99))

slice_mvn_samples <- check_mvn_samples(sampler = slice())
expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99))
})

20 changes: 20 additions & 0 deletions tests/testthat/test_posteriors_chi_squared.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,23 @@ test_that("samplers are unbiased for chi-squared", {

expect_gte(stat$p.value, 0.01)
})

test_that("samplers are unbiased for chi-squared with adaptive hmc", {
  skip_if_not(check_tf_version())
  skip_on_os("windows")

  # chi-squared with 5 degrees of freedom, compared against iid draws
  dof <- 5
  x <- chi_squared(dof)
  iid <- function(n) rchisq(n, dof)

  chi_squared_checked <- check_samples(
    x = x,
    iid_function = iid,
    sampler = adaptive_hmc()
  )

  # visual check
  qqplot_checked_samples(chi_squared_checked)

  # formal check: KS test of mcmc samples vs iid samples
  stat <- ks_test_mcmc_vs_iid(chi_squared_checked)

  expect_gte(stat$p.value, 0.01)
})
Loading
Loading