diff --git a/R/samplers.R b/R/samplers.R
index b1591e82..f5771953 100644
--- a/R/samplers.R
+++ b/R/samplers.R
@@ -309,3 +309,69 @@ slice_sampler <- R6Class(
     }
   )
 )
+
+adaptive_hmc_sampler <- R6Class(
+  "adaptive_hmc_sampler",
+  inherit = sampler,
+  public = list(
+    parameters = list(
+      # Lmin = 10,
+      # Lmax = 20,
+      max_leapfrog_steps = 1000,
+      # TODO clean up the usage of these parameters elsewhere
+      # epsilon = 0.005,
+      # diag_sd = 1,
+      # TODO add a validity check for method? Currently this can only be
+      # "SNAPER".
+      method = "SNAPER"
+    ),
+    accept_target = 0.651,
+
+    define_tf_kernel = function(sampler_param_vec) {
+      dag <- self$model$dag
+      tfe <- dag$tf_environment
+
+      free_state_size <- length(sampler_param_vec) - 2
+
+      adaptive_hmc_max_leapfrog_steps <- tf$cast(
+        x = sampler_param_vec[0],
+        dtype = tf$int32
+      )
+      # TODO pipe that in properly
+      n_warmup <- sampler_param_vec[1]
+      # adaptive_hmc_epsilon <- sampler_param_vec[1]
+      # adaptive_hmc_diag_sd <- sampler_param_vec[2:(1+free_state_size)]
+
+      kernel_base <- tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo(
+        target_log_prob_fn = dag$tf_log_prob_function_adjusted,
+        step_size = 1,
+        num_adaptation_steps = as.integer(self$warmup),
+        max_leapfrog_steps = adaptive_hmc_max_leapfrog_steps
+      )
+
+      sampler_kernel <- tfp$mcmc$DualAveragingStepSizeAdaptation(
+        inner_kernel = kernel_base,
+        num_adaptation_steps = as.integer(self$warmup)
+      )
+
+      return(
+        sampler_kernel
+      )
+    },
+    sampler_parameter_values = function() {
+      # maximum number of leapfrog integration steps
+      max_leapfrog_steps <- self$parameters$max_leapfrog_steps
+      # epsilon <- self$parameters$epsilon
+      # diag_sd <- matrix(self$parameters$diag_sd)
+      method <- self$parameters$method
+
+      # return named list for replacing tensors
+      list(
+        adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps,
+        # adaptive_hmc_epsilon = epsilon,
+        # adaptive_hmc_diag_sd = diag_sd,
+        method = method
+      )
+    }
+  )
+)
diff --git a/R/utils.R b/R/utils.R
index 679a479c..4ff9601a 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -1079,3 +1079,85 @@ n_warmup <- function(x) {
   x_info <- attr(x, "model_info")
   x_info$warmup
 }
+
+as_tensor_spec <- function(tensor) {
+  tf$TensorSpec(shape = tensor$shape, dtype = tensor$dtype)
+}
+
+maybe_make_tensor_shape <- function(x) {
+  if (tf$is_tensor(x)) {
+    as_tensor_spec(x)
+  } else {
+    x
+  }
+}
+
+# get the final model parameter state from a chain as returned in the all_states
+# object from tfp$mcmc$sample_chain
+get_last_state <- function(all_states) {
+  n_iter <- dim(all_states)[1]
+  tf$gather(all_states, n_iter - 1L, axis = 0L)
+}
+
+# find out if MCMC steps had non-finite acceptance probabilities
+bad_steps <- function(kernel_results) {
+  log_accept_ratios <- recursive_get_log_accept_ratio(kernel_results)
+  !is.finite(log_accept_ratios)
+}
+
+
+# recursively extract the log acceptance ratio from the MCMC kernel
+recursive_get_log_accept_ratio <- function(kernel_results) {
+  nm <- names(kernel_results)
+  if ("log_accept_ratio" %in% nm) {
+    log_accept_ratios <- kernel_results$log_accept_ratio
+  } else if ("inner_results" %in% nm) {
+    log_accept_ratios <- recursive_get_log_accept_ratio(
+      kernel_results$inner_results
+    )
+  } else {
+    stop("non-standard kernel structure")
+  }
+  as.array(log_accept_ratios)
+}
+
+
+# given a warmed up sampler object, return a compiled TF function that generates
+# a new burst of samples from it
+make_sampler_function <- function(warm_sampler) {
+  # make the uncompiled function (with curried arguments)
+  sample_raw <- function(current_state, n_samples) {
+    results <- tfp$mcmc$sample_chain(
+      # how many iterations
+      num_results = n_samples,
+      # where to start from
+      current_state = current_state,
+      # kernel
+      kernel = warm_sampler$kernel,
+      # tuned sampler settings
+      previous_kernel_results = warm_sampler$kernel_results,
+      # what to trace (nothing)
+      trace_fn = function(current_state, kernel_results) {
+        # could compute badness here to save memory?
+        # is.finite(kernel_results$inner_results$inner_results$inner_results$log_accept_ratio)
+        kernel_results
+      }
+    )
+    # return the parameter states and the kernel results
+    list(
+      all_states = results$all_states,
+      kernel_results = results$trace
+    )
+  }
+
+  # compile it into a concrete function and return
+  sample <- tf_function(
+    sample_raw,
+    list(
+      as_tensor_spec(warm_sampler$current_state),
+      tf$TensorSpec(shape = c(), dtype = tf$int32)
+    )
+  )
+
+  sample
+}
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index fe589749..5273e3d8 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -714,6 +714,7 @@ p_theta_greta <- function(
   data,
   p_theta,
   p_x_bar_theta,
+  # TODO note that we might want to change this to adaptive_hmc()
   sampler = hmc(),
   warmup = 1000
 ) {
@@ -934,6 +935,7 @@ get_distribution_name <- function(x) {
 check_samples <- function(
   x,
   iid_function,
+  # TODO note that we might want to change this to adaptive_hmc
   sampler = hmc(),
   n_effective = 3000,
   title = NULL,
diff --git a/tests/testthat/test-adaptive-hmc.R b/tests/testthat/test-adaptive-hmc.R
new file mode 100644
index 00000000..45874fa6
--- /dev/null
+++ b/tests/testthat/test-adaptive-hmc.R
@@ -0,0 +1,61 @@
+set.seed(2025 - 02 - 13)
+
+test_that("bad mcmc proposals are rejected", {
+  skip_if_not(check_tf_version())
+
+  # set up for numerical rejection of initial location
+  x <- rnorm(10000, 1e60, 1)
+  z <- normal(-1e60, 1e-60)
+  distribution(x) <- normal(z, 1e-60)
+  m <- model(z, precision = "single")
+
+  # catch badness in the progress bar
+  out <- get_output(
+    mcmc(m, n_samples = 10, warmup = 0, pb_update = 10)
+  )
+  expect_match(out, "100% bad")
+
+  expect_snapshot(
+    error = TRUE,
+    draws <- mcmc(
+      m,
+      chains = 1,
+      n_samples = 2,
+      warmup = 0,
+      verbose = FALSE,
+      sampler = adaptive_hmc(),
+      initial_values = initials(z = 1e120)
+    )
+  )
+
+  # really bad proposals
+  x <- rnorm(100000, 1e120, 1)
+  z <- normal(-1e120, 1e-120)
+  distribution(x) <- normal(z, 1e-120)
+  m <- model(z, precision = "single")
+  expect_snapshot(
+    error = TRUE,
+    mcmc(m,
+      chains = 1,
+      n_samples = 1,
+      warmup = 0,
+      sampler = adaptive_hmc(),
+      verbose = FALSE)
+  )
+
+  # proposals that are fine, but rejected anyway
+  z <- normal(0, 1)
+  m <- model(z, precision = "single")
+  expect_ok(mcmc(
+    m,
+    adaptive_hmc(
+      epsilon = 100,
+      # Lmin = 1,
+      # Lmax = 1
+    ),
+    chains = 1,
+    n_samples = 5,
+    warmup = 0,
+    verbose = FALSE
+  ))
+})
diff --git a/tests/testthat/test_inference.R b/tests/testthat/test_inference.R
index 4f883a04..5776572e 100644
--- a/tests/testthat/test_inference.R
+++ b/tests/testthat/test_inference.R
@@ -422,6 +422,9 @@ test_that("samplers print informatively", {
   expect_snapshot(
     hmc(Lmin = 1)
   )
+  expect_snapshot(
+    adaptive_hmc(Lmin = 1)
+  )
 
   # # check print sees changed parameters
   # out <- capture_output(hmc(Lmin = 1), TRUE)
diff --git a/tests/testthat/test_posteriors_binomial.R b/tests/testthat/test_posteriors_binomial.R
index 92d4a087..5772e7d7 100644
--- a/tests/testthat/test_posteriors_binomial.R
+++ b/tests/testthat/test_posteriors_binomial.R
@@ -29,3 +29,35 @@ test_that("posterior is correct (binomial)", {
(binomial)", { suppressWarnings(test <- ks.test(samples, comparison)) expect_gte(test$p.value, 0.01) }) + +test_that("posterior is correct (binomial) with adaptive_hmc", { + skip_if_not(check_tf_version()) + skip_on_os("windows") + # analytic solution to the posterior of the paramter of a binomial + # distribution, with uniform prior + n <- 100 + pos <- rbinom(1, n, runif(1)) + theta <- uniform(0, 1) + distribution(pos) <- binomial(n, theta) + m <- model(theta) + + draws <- get_enough_draws(m, adaptive_hmc(), 2000, verbose = FALSE) + + samples <- as.matrix(draws) + + # analytic solution to posterior is beta(1 + pos, 1 + N - pos) + shape1 <- 1 + pos + shape2 <- 1 + n - pos + + # qq plot against true quantiles + quants <- (1:99) / 100 + q_target <- qbeta(quants, shape1, shape2) + q_est <- quantile(samples, quants) + plot(q_target ~ q_est, main = "binomial posterior") + abline(0, 1) + + n_draws <- round(coda::effectiveSize(draws)) + comparison <- rbeta(n_draws, shape1, shape2) + suppressWarnings(test <- ks.test(samples, comparison)) + expect_gte(test$p.value, 0.01) +}) diff --git a/tests/testthat/test_posteriors_chi_squared.R b/tests/testthat/test_posteriors_chi_squared.R index 1332d347..63d9e842 100644 --- a/tests/testthat/test_posteriors_chi_squared.R +++ b/tests/testthat/test_posteriors_chi_squared.R @@ -19,3 +19,25 @@ test_that("samplers are unbiased for chi-squared", { expect_gte(stat$p.value, 0.01) }) + +test_that("samplers are unbiased for chi-squared", { + skip_if_not(check_tf_version()) + skip_on_os("windows") + df <- 5 + x <- chi_squared(df) + iid <- function(n) rchisq(n, df) + + chi_squared_checked <- check_samples( + x = x, + iid_function = iid, + sampler = adaptive_hmc() + ) + + # do the plotting + qqplot_checked_samples(chi_squared_checked) + + # do a formal hypothesis test + stat <- ks_test_mcmc_vs_iid(chi_squared_checked) + + expect_gte(stat$p.value, 0.01) +}) diff --git a/tests/testthat/test_posteriors_geweke.R b/tests/testthat/test_posteriors_geweke.R index 40a6cac5..2f2148ea 100644 --- a/tests/testthat/test_posteriors_geweke.R +++ b/tests/testthat/test_posteriors_geweke.R @@ -1,38 +1,35 @@ Sys.setenv("RELEASE_CANDIDATE" = "false") -test_that("samplers pass geweke tests", { +# run geweke tests on this model: +# theta ~ normal(mu1, sd1) +# x[i] ~ normal(theta, sd2) +# for i in N + +n <- 10 +mu1 <- rnorm(1, 0, 3) +sd1 <- rlnorm(1) +sd2 <- rlnorm(1) + +# prior (n draws) +p_theta <- function(n) { + rnorm(n, mu1, sd1) +} + +# likelihood +p_x_bar_theta <- function(theta) { + rnorm(n, theta, sd2) +} + +# define the greta model (single precision for slice sampler) +x <- as_data(rep(0, n)) +greta_theta <- normal(mu1, sd1) +distribution(x) <- normal(greta_theta, sd2) +model <- model(greta_theta, precision = "single") + +test_that("hmc sampler passes geweke tests", { skip_if_not(check_tf_version()) - skip_if_not_release() - # nolint start - # run geweke tests on this model: - # theta ~ normal(mu1, sd1) - # x[i] ~ normal(theta, sd2) - # for i in N - # nolint end - - n <- 10 - mu1 <- rnorm(1, 0, 3) - sd1 <- rlnorm(1) - sd2 <- rlnorm(1) - - # prior (n draws) - p_theta <- function(n) { - rnorm(n, mu1, sd1) - } - - # likelihood - p_x_bar_theta <- function(theta) { - rnorm(n, theta, sd2) - } - - # define the greta model (single precision for slice sampler) - x <- as_data(rep(0, n)) - greta_theta <- normal(mu1, sd1) - distribution(x) <- normal(greta_theta, sd2) - model <- model(greta_theta, precision = "single") - # run tests on all available samplers geweke_hmc <- check_geweke( sampler = 
@@ -48,8 +45,13 @@ test_that("samplers pass geweke tests", {
   geweke_stat_hmc <- geweke_ks(geweke_hmc)
 
   testthat::expect_gte(geweke_stat_hmc$p.value, 0.005)
+})
 
-  geweke_hmc_rwmh <- check_geweke(
+test_that("rwmh sampler passes geweke tests", {
+  skip_if_not(check_tf_version())
+  skip_if_not_release()
+
+  geweke_rwmh <- check_geweke(
     sampler = rwmh(),
     model = model,
     data = x,
@@ -59,13 +61,19 @@ test_that("samplers pass geweke tests", {
     thin = 5
   )
 
-  geweke_qq(geweke_hmc_rwmh, title = "RWMH Geweke test")
+  geweke_qq(geweke_rwmh, title = "RWMH Geweke test")
 
-  geweke_stat_rwmh <- geweke_ks(geweke_hmc_rwmh)
+  geweke_stat_rwmh <- geweke_ks(geweke_rwmh)
 
   testthat::expect_gte(geweke_stat_rwmh$p.value, 0.005)
+})
 
-  geweke_hmc_slice <- check_geweke(
+test_that("slice sampler passes geweke tests", {
+  skip_if_not(check_tf_version())
+
+  skip_if_not_release()
+
+  geweke_slice <- check_geweke(
     sampler = slice(),
     model = model,
     data = x,
@@ -73,7 +81,27 @@ test_that("samplers pass geweke tests", {
     p_x_bar_theta = p_x_bar_theta
   )
 
-  geweke_qq(geweke_hmc_slice, title = "slice sampler Geweke test")
+  geweke_qq(geweke_slice, title = "slice sampler Geweke test")
+
+  geweke_stat_slice <- geweke_ks(geweke_slice)
+
+  testthat::expect_gte(geweke_stat_slice$p.value, 0.005)
+})
+
+test_that("adaptive hmc sampler passes geweke tests", {
+  skip_if_not(check_tf_version())
+
+  skip_if_not_release()
+
+  geweke_adaptive_hmc <- check_geweke(
+    sampler = adaptive_hmc(),
+    model = model,
+    data = x,
+    p_theta = p_theta,
+    p_x_bar_theta = p_x_bar_theta
+  )
+
+  geweke_qq(geweke_adaptive_hmc, title = "adaptive hmc sampler Geweke test")
 
-  testthat::expect_gte(geweke_hmc_slice$p.value, 0.005)
+  testthat::expect_gte(geweke_ks(geweke_adaptive_hmc)$p.value, 0.005)
 })
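
A minimal usage sketch of the new sampler, mirroring the tests above. The user-facing adaptive_hmc() constructor is assumed here (it is called by the tests but not defined in this diff), so the max_leapfrog_steps argument name is an assumption taken from adaptive_hmc_sampler$parameters.

library(greta)

# a small model, similar to the geweke test setup
x <- as_data(rnorm(10))
theta <- normal(0, 3)
distribution(x) <- normal(theta, 1)
m <- model(theta)

# sample with the SNAPER-based adaptive HMC kernel via the usual mcmc() interface;
# adaptive_hmc() and its argument are assumed, not defined in this diff
draws <- mcmc(
  m,
  sampler = adaptive_hmc(max_leapfrog_steps = 1000),
  warmup = 1000,
  n_samples = 1000
)

summary(draws)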