From 58c23a3afa5af4958d372f7edb1d8844677b3845 Mon Sep 17 00:00:00 2001 From: njtierney Date: Wed, 12 Feb 2025 17:09:04 +0800 Subject: [PATCH 01/11] first draft of implementing the bones/structure of snaper_hmc --- R/inference_class.R | 86 +++++++++++++++++++++++++++++++++++++++++++++ R/samplers.R | 33 +++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/R/inference_class.R b/R/inference_class.R index 0ff00469..3ea43919 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -1055,3 +1055,89 @@ slice_sampler <- R6Class( } ) ) + +snaper_hmc_sampler <- R6Class( + "snaper_hmc_sampler", + inherit = sampler, + public = list( + parameters = list( + # Lmin = 10, + # Lmax = 20, + max_leapfrog_steps = 1000, + epsilon = 0.005, + diag_sd = 1 + ), + accept_target = 0.651, + + define_tf_kernel = function(sampler_param_vec) { + + dag <- self$model$dag + tfe <- dag$tf_environment + + # TODO double check this + free_state_size <- length(sampler_param_vec) - 2 + + # TF1/2 check + # this will likely get replaced... + + s_hmc_max_leapfrog_steps <- sampler_param_vec[0] + s_hmc_epsilon <- sampler_param_vec[2] + s_hmc_diag_sd <- sampler_param_vec[2:(1+free_state_size)] + + hmc_step_sizes <- tf$cast( + x = tf$reshape( + hmc_epsilon * (s_hmc_diag_sd / tf$reduce_sum(s_hmc_diag_sd)), + shape = shape(free_state_size) + ), + dtype = tf$float64 + ) + # TF1/2 check + # where is "free_state" pulled from, given that it is the + # argument to this function, "generate_log_prob_function" ? + # log probability function + + # build the kernel + # nolint start + + # sampler_kernel <- tfp$mcmc$HamiltonianMonteCarlo( + # target_log_prob_fn = dag$tf_log_prob_function_adjusted, + # step_size = hmc_step_sizes, + # num_leapfrog_steps = hmc_l + # ) + + kernel_base <- tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo( + target_log_prob_fn = dag$tf_log_prob_function_adjusted, + # TODO do we want to use hmc_step_size or 1? + # step_size = 1, + step_size = hmc_step_sizes, + # num_adaptation_steps = as.integer(n_warmup * 0.9)) + # TODO check n_warmup? + num_adaptation_steps = as.integer(s_hmc_max_leapfrog_steps * 0.9)) + + sampler_kernel <- tfp$mcmc$DualAveragingStepSizeAdaptation( + inner_kernel = kernel_base, + # num_adaptation_steps = as.integer(n_warmup)) + # TODO check n_warmup? + num_adaptation_steps = as.integer(s_hmc_max_leapfrog_steps)) + + + return( + sampler_kernel + ) + }, + sampler_parameter_values = function() { + + # random number of integration steps + max_leapfrog_steps <- self$parameters$max_leapfrog_steps + epsilon <- self$parameters$epsilon + diag_sd <- matrix(self$parameters$diag_sd) + + # return named list for replacing tensors + list( + s_hmc_max_leapfrog_steps = max_leapfrog_steps, + s_hmc_epsilon = epsilon, + s_hmc_diag_sd = diag_sd + ) + } + ) +) diff --git a/R/samplers.R b/R/samplers.R index e83d3125..dbd051a5 100644 --- a/R/samplers.R +++ b/R/samplers.R @@ -90,6 +90,39 @@ slice <- function(max_doublings = 5) { obj } +#' @rdname samplers +#' @export +#' +#' @param epsilon leapfrog stepsize hyperparameter (positive, will be tuned) +#' @param diag_sd estimate of the posterior marginal standard deviations +#' (positive, will be tuned). +#' @param max_leapfrog_steps numeric. Default 1000. Maximum number of leapfrog +#' steps used. The algorithm will determine the optimal number less than this. +#' +#' @details For `snaper_hmc()`. The Lmin and Lmax parameters are learnt and so +#' not probivided in this. The number of chains cannot be less than 2, due to +#' how Snaper HMC works. 
`diag_sd` is used to rescale the parameter space to +#' make it more uniform, and make sampling more efficient. +snaper_hmc <- function( + max_leapfrog_steps = 1000, + epsilon = 0.1, + diag_sd = 1 + ) { + # nolint end + obj <- list( + parameters = list( + max_leapfrog_steps = max_leapfrog_steps, + epsilon = epsilon, + diag_sd = diag_sd + ), + class = snaper_hmc_sampler + ) + class(obj) <- c("snaper hmc sampler", "sampler") + obj +} + + + #' @noRd #' @export print.sampler <- function(x, ...) { From 57e25be7b109d29ddd771e116be5b41bd1303bf7 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 13 Feb 2025 14:53:02 +0800 Subject: [PATCH 02/11] Some test driven development: add adaptive_hmc tests into tests before it is all working --- tests/testthat/helpers.R | 2 + tests/testthat/test-adaptive-hmc.R | 51 +++++++++ tests/testthat/test_inference.R | 3 + tests/testthat/test_posteriors_binomial.R | 34 +++++- .../test_posteriors_bivariate_normal.R | 4 + tests/testthat/test_posteriors_chi_squared.R | 20 ++++ tests/testthat/test_posteriors_geweke.R | 100 ++++++++++++------ 7 files changed, 178 insertions(+), 36 deletions(-) create mode 100644 tests/testthat/test-adaptive-hmc.R diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 4a2a8959..83a2361a 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -706,6 +706,7 @@ p_theta_greta <- function( data, p_theta, p_x_bar_theta, + # TODO note that we might want to change this to adaptive_hmc() sampler = hmc(), warmup = 1000 ) { @@ -926,6 +927,7 @@ get_distribution_name <- function(x){ check_samples <- function( x, iid_function, + # TODO note that we might want to change this to adaptive_hmc sampler = hmc(), n_effective = 3000, title = NULL, diff --git a/tests/testthat/test-adaptive-hmc.R b/tests/testthat/test-adaptive-hmc.R new file mode 100644 index 00000000..450c2fe7 --- /dev/null +++ b/tests/testthat/test-adaptive-hmc.R @@ -0,0 +1,51 @@ +set.seed(2025 - 02 - 13) + +test_that("bad mcmc proposals are rejected", { + skip_if_not(check_tf_version()) + + # set up for numerical rejection of initial location + x <- rnorm(10000, 1e60, 1) + z <- normal(-1e60, 1e-60) + distribution(x) <- normal(z, 1e-60) + m <- model(z, precision = "single") + + # # catch badness in the progress bar + out <- get_output( + mcmc(m, n_samples = 10, warmup = 0, pb_update = 10) + ) + expect_match(out, "100% bad") + + expect_snapshot(error = TRUE, + draws <- mcmc(m, + chains = 1, + n_samples = 2, + warmup = 0, + verbose = FALSE, + initial_values = initials(z = 1e120) + ) + ) + + # really bad proposals + x <- rnorm(100000, 1e120, 1) + z <- normal(-1e120, 1e-120) + distribution(x) <- normal(z, 1e-120) + m <- model(z, precision = "single") + expect_snapshot(error = TRUE, + mcmc(m, chains = 1, n_samples = 1, warmup = 0, verbose = FALSE) + ) + + # proposals that are fine, but rejected anyway + z <- normal(0, 1) + m <- model(z, precision = "single") + expect_ok(mcmc(m, + adaptive_hmc( + epsilon = 100, + # Lmin = 1, + # Lmax = 1 + ), + chains = 1, + n_samples = 5, + warmup = 0, + verbose = FALSE + )) +}) diff --git a/tests/testthat/test_inference.R b/tests/testthat/test_inference.R index b70f2f15..c67d5b11 100644 --- a/tests/testthat/test_inference.R +++ b/tests/testthat/test_inference.R @@ -381,6 +381,9 @@ test_that("samplers print informatively", { expect_snapshot( hmc(Lmin = 1) ) + expect_snapshot( + adaptive_hmc(Lmin = 1) + ) # # check print sees changed parameters # out <- capture_output(hmc(Lmin = 1), TRUE) diff --git 
a/tests/testthat/test_posteriors_binomial.R b/tests/testthat/test_posteriors_binomial.R index 92d4a087..6a8dc2dc 100644 --- a/tests/testthat/test_posteriors_binomial.R +++ b/tests/testthat/test_posteriors_binomial.R @@ -1,4 +1,4 @@ -test_that("posterior is correct (binomial)", { +test_that("posterior is correct (binomial) with hmc", { skip_if_not(check_tf_version()) skip_on_os("windows") # analytic solution to the posterior of the paramter of a binomial @@ -29,3 +29,35 @@ test_that("posterior is correct (binomial)", { suppressWarnings(test <- ks.test(samples, comparison)) expect_gte(test$p.value, 0.01) }) + +test_that("posterior is correct (binomial) with adaptive hmc", { + skip_if_not(check_tf_version()) + skip_on_os("windows") + # analytic solution to the posterior of the paramter of a binomial + # distribution, with uniform prior + n <- 100 + pos <- rbinom(1, n, runif(1)) + theta <- uniform(0, 1) + distribution(pos) <- binomial(n, theta) + m <- model(theta) + + draws <- get_enough_draws(m, adaptive_hmc(), 2000, verbose = FALSE) + + samples <- as.matrix(draws) + + # analytic solution to posterior is beta(1 + pos, 1 + N - pos) + shape1 <- 1 + pos + shape2 <- 1 + n - pos + + # qq plot against true quantiles + quants <- (1:99) / 100 + q_target <- qbeta(quants, shape1, shape2) + q_est <- quantile(samples, quants) + plot(q_target ~ q_est, main = "binomial posterior") + abline(0, 1) + + n_draws <- round(coda::effectiveSize(draws)) + comparison <- rbeta(n_draws, shape1, shape2) + suppressWarnings(test <- ks.test(samples, comparison)) + expect_gte(test$p.value, 0.01) +}) diff --git a/tests/testthat/test_posteriors_bivariate_normal.R b/tests/testthat/test_posteriors_bivariate_normal.R index eb36bcea..a1b4b362 100644 --- a/tests/testthat/test_posteriors_bivariate_normal.R +++ b/tests/testthat/test_posteriors_bivariate_normal.R @@ -5,9 +5,13 @@ test_that("samplers are unbiased for bivariate normals", { hmc_mvn_samples <- check_mvn_samples(sampler = hmc()) expect_lte(max(hmc_mvn_samples), stats::qnorm(0.99)) + adaptive_hmc_mvn_samples <- check_mvn_samples(sampler = adaptive_hmc()) + expect_lte(max(adaptive_hmc_mvn_samples), stats::qnorm(0.99)) + rwmh_mvn_samples <- check_mvn_samples(sampler = rwmh()) expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99)) slice_mvn_samples <- check_mvn_samples(sampler = slice()) expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99)) }) + diff --git a/tests/testthat/test_posteriors_chi_squared.R b/tests/testthat/test_posteriors_chi_squared.R index 68228fd6..7eae3229 100644 --- a/tests/testthat/test_posteriors_chi_squared.R +++ b/tests/testthat/test_posteriors_chi_squared.R @@ -17,3 +17,23 @@ test_that("samplers are unbiased for chi-squared", { expect_gte(stat$p.value, 0.01) }) + +test_that("samplers are unbiased for chi-squared with adaptive hmc", { + skip_if_not(check_tf_version()) + skip_on_os("windows") + df <- 5 + x <- chi_squared(df) + iid <- function(n) rchisq(n, df) + + chi_squared_checked <- check_samples(x = x, + iid_function = iid, + sampler = adaptive_hmc()) + + # do the plotting + qqplot_checked_samples(chi_squared_checked) + + # do a formal hypothesis test + stat <- ks_test_mcmc_vs_iid(chi_squared_checked) + + expect_gte(stat$p.value, 0.01) +}) diff --git a/tests/testthat/test_posteriors_geweke.R b/tests/testthat/test_posteriors_geweke.R index ef810518..d7fdaa5f 100644 --- a/tests/testthat/test_posteriors_geweke.R +++ b/tests/testthat/test_posteriors_geweke.R @@ -1,38 +1,36 @@ Sys.setenv("RELEASE_CANDIDATE" = "false") -test_that("samplers pass geweke 
tests", { +# run geweke tests on this model: +# theta ~ normal(mu1, sd1) +# x[i] ~ normal(theta, sd2) +# for i in N + +n <- 10 +mu1 <- rnorm(1, 0, 3) +sd1 <- rlnorm(1) +sd2 <- rlnorm(1) + +# prior (n draws) +p_theta <- function(n) { + rnorm(n, mu1, sd1) +} + +# likelihood +p_x_bar_theta <- function(theta) { + rnorm(n, theta, sd2) +} + +# define the greta model (single precision for slice sampler) +x <- as_data(rep(0, n)) +greta_theta <- normal(mu1, sd1) +distribution(x) <- normal(greta_theta, sd2) +model <- model(greta_theta, precision = "single") + +test_that("hmc sampler passes geweke tests", { skip_if_not(check_tf_version()) skip_if_not_release() - # nolint start - # run geweke tests on this model: - # theta ~ normal(mu1, sd1) - # x[i] ~ normal(theta, sd2) - # for i in N - # nolint end - - n <- 10 - mu1 <- rnorm(1, 0, 3) - sd1 <- rlnorm(1) - sd2 <- rlnorm(1) - - # prior (n draws) - p_theta <- function(n) { - rnorm(n, mu1, sd1) - } - - # likelihood - p_x_bar_theta <- function(theta) { - rnorm(n, theta, sd2) - } - - # define the greta model (single precision for slice sampler) - x <- as_data(rep(0, n)) - greta_theta <- normal(mu1, sd1) - distribution(x) <- normal(greta_theta, sd2) - model <- model(greta_theta, precision = "single") - # run tests on all available samplers geweke_hmc <- check_geweke( sampler = hmc(), @@ -48,8 +46,14 @@ test_that("samplers pass geweke tests", { geweke_stat_hmc <- geweke_ks(geweke_hmc) testthat::expect_gte(geweke_stat_hmc$p.value, 0.005) +}) + +test_that("rwmh sampler passes geweke tests", { + skip_if_not(check_tf_version()) - geweke_hmc_rwmh <- check_geweke( + skip_if_not_release() + + geweke_rwmh <- check_geweke( sampler = rwmh(), model = model, data = x, @@ -59,13 +63,19 @@ test_that("samplers pass geweke tests", { thin = 5 ) - geweke_qq(geweke_hmc_rwmh, title = "RWMH Geweke test") + geweke_qq(geweke_rwmh, title = "RWMH Geweke test") - geweke_stat_rwmh <- geweke_ks(geweke_hmc_rwmh) + geweke_stat_rwmh <- geweke_ks(geweke_rwmh) testthat::expect_gte(geweke_stat_rwmh$p.value, 0.005) +}) - geweke_hmc_slice <- check_geweke( +test_that("slice sampler passes geweke tests", { + skip_if_not(check_tf_version()) + + skip_if_not_release() + + geweke_slice <- check_geweke( sampler = slice(), model = model, data = x, @@ -73,8 +83,28 @@ test_that("samplers pass geweke tests", { p_x_bar_theta = p_x_bar_theta ) - geweke_qq(geweke_hmc_slice, title = "slice sampler Geweke test") + geweke_qq(geweke_slice, title = "slice sampler Geweke test") + + testthat::expect_gte(geweke_slice$p.value, 0.005) + +}) + + +test_that("adaptive hmc sampler passes geweke tests", { + skip_if_not(check_tf_version()) + + skip_if_not_release() + + geweke_adaptive_hmc <- check_geweke( + sampler = adaptive_hmc(), + model = model, + data = x, + p_theta = p_theta, + p_x_bar_theta = p_x_bar_theta + ) + + geweke_qq(geweke_adaptive_hmc, title = "adaptive hmc sampler Geweke test") - testthat::expect_gte(geweke_hmc_slice$p.value, 0.005) + testthat::expect_gte(geweke_adaptive_hmc$p.value, 0.005) }) From 188bf1dac09bbcbff08f55279c19ebc943e595f7 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 13 Feb 2025 15:35:06 +0800 Subject: [PATCH 03/11] a bit more progress on implementing adaptive HMC --- R/inference_class.R | 78 +++++++++++++++++++-------------------------- R/samplers.R | 24 +++++++++----- 2 files changed, 48 insertions(+), 54 deletions(-) diff --git a/R/inference_class.R b/R/inference_class.R index 3ea43919..42e8fbd7 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -157,6 +157,7 @@ 
inference <- R6Class( # check whether the model can be evaluated at these parameters valid_parameters = function(parameters) { + # ADAPTIVE HMC - this is where this breaks dag <- self$model$dag tf_parameters <- fl(array( data = parameters, @@ -668,6 +669,7 @@ sampler <- R6Class( # pass values through ) { + browser() dag <- self$model$dag tfe <- dag$tf_environment @@ -693,6 +695,7 @@ sampler <- R6Class( # Need to work out how to get sampler_batch() to run as a TF function. # To do that we need to work out how to get the free state + browser() sampler_batch <- tfp$mcmc$sample_chain( num_results = tf$math$floordiv(sampler_burst_length, sampler_thin), current_state = free_state, @@ -794,6 +797,8 @@ sampler <- R6Class( # legacy: previously we used `n_samples` not `sampler_burst_length` n_samples <- sampler_burst_length + # ADPATIVE HMC + # TODO - this is where the adaptive_hmc fails at the moment result <- cleanly( self$tf_evaluate_sample_batch( free_state = tensorflow::as_tensor( @@ -877,7 +882,6 @@ hmc_sampler <- R6Class( accept_target = 0.651, define_tf_kernel = function(sampler_param_vec) { - dag <- self$model$dag tfe <- dag$tf_environment @@ -1056,70 +1060,50 @@ slice_sampler <- R6Class( ) ) -snaper_hmc_sampler <- R6Class( - "snaper_hmc_sampler", +adaptive_hmc_sampler <- R6Class( + "adaptive_hmc_sampler", inherit = sampler, public = list( parameters = list( # Lmin = 10, # Lmax = 20, max_leapfrog_steps = 1000, - epsilon = 0.005, - diag_sd = 1 + # TODO clean up these parameter usage else where + # epsilon = 0.005, + # diag_sd = 1, + # TODO some kind of validity check of method? Currently this can only be + # "SNAPER". + method = "SNAPER" ), accept_target = 0.651, define_tf_kernel = function(sampler_param_vec) { - + browser() dag <- self$model$dag tfe <- dag$tf_environment - # TODO double check this free_state_size <- length(sampler_param_vec) - 2 - # TF1/2 check - # this will likely get replaced... - - s_hmc_max_leapfrog_steps <- sampler_param_vec[0] - s_hmc_epsilon <- sampler_param_vec[2] - s_hmc_diag_sd <- sampler_param_vec[2:(1+free_state_size)] - - hmc_step_sizes <- tf$cast( - x = tf$reshape( - hmc_epsilon * (s_hmc_diag_sd / tf$reduce_sum(s_hmc_diag_sd)), - shape = shape(free_state_size) - ), - dtype = tf$float64 + adaptive_hmc_max_leapfrog_steps <- tf$cast( + x = sampler_param_vec[0], + dtype = tf$int32 ) - # TF1/2 check - # where is "free_state" pulled from, given that it is the - # argument to this function, "generate_log_prob_function" ? - # log probability function - - # build the kernel - # nolint start - - # sampler_kernel <- tfp$mcmc$HamiltonianMonteCarlo( - # target_log_prob_fn = dag$tf_log_prob_function_adjusted, - # step_size = hmc_step_sizes, - # num_leapfrog_steps = hmc_l - # ) + # TODO pipe that in properly + n_warmup <- sampler_param_vec[1] + # adaptive_hmc_epsilon <- sampler_param_vec[1] + # adaptive_hmc_diag_sd <- sampler_param_vec[2:(1+free_state_size)] kernel_base <- tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo( target_log_prob_fn = dag$tf_log_prob_function_adjusted, - # TODO do we want to use hmc_step_size or 1? - # step_size = 1, - step_size = hmc_step_sizes, - # num_adaptation_steps = as.integer(n_warmup * 0.9)) - # TODO check n_warmup? 
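The slicing near the top of this `define_tf_kernel()` follows a simple packing convention between the R side (`sampler_parameter_values()`) and the TF side. A minimal sketch of that convention, assuming greta's 0-based tensor extraction; note that with `diag_sd` starting at index 2, `epsilon` belongs at index 1:

# R side: flatten the named parameter list into one double vector
params <- list(
  max_leapfrog_steps = 1000,
  epsilon = 0.005,
  diag_sd = rep(1, 3)  # one entry per free-state dimension
)
param_vec <- unlist(params)

# TF side: slice the tensor back apart (tensor extraction is 0-based):
#   sampler_param_vec[0]                       -> max_leapfrog_steps
#   sampler_param_vec[1]                       -> epsilon
#   sampler_param_vec[2:(1 + free_state_size)] -> diag_sd
# which is why free_state_size is length(sampler_param_vec) - 2
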
-      num_adaptation_steps = as.integer(s_hmc_max_leapfrog_steps * 0.9))
+      step_size = 1,
+      num_adaptation_steps = as.integer(self$warmup),
+      max_leapfrog_steps = adaptive_hmc_max_leapfrog_steps
+      )
 
       sampler_kernel <- tfp$mcmc$DualAveragingStepSizeAdaptation(
         inner_kernel = kernel_base,
-        # num_adaptation_steps = as.integer(n_warmup))
-        # TODO check n_warmup?
-        num_adaptation_steps = as.integer(s_hmc_max_leapfrog_steps))
-
+        num_adaptation_steps = as.integer(self$warmup)
+      )
 
       return(
         sampler_kernel
@@ -1131,12 +1115,14 @@ snaper_hmc_sampler <- R6Class(
       max_leapfrog_steps <- self$parameters$max_leapfrog_steps
       epsilon <- self$parameters$epsilon
       diag_sd <- matrix(self$parameters$diag_sd)
+      method <- self$parameters$method
 
       # return named list for replacing tensors
       list(
-        s_hmc_max_leapfrog_steps = max_leapfrog_steps,
-        s_hmc_epsilon = epsilon,
-        s_hmc_diag_sd = diag_sd
+        adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps,
+        # adaptive_hmc_epsilon = epsilon,
+        # adaptive_hmc_diag_sd = diag_sd,
+        method = method
       )
     }
   )
diff --git a/R/samplers.R b/R/samplers.R
index dbd051a5..612fe2cb 100644
--- a/R/samplers.R
+++ b/R/samplers.R
@@ -98,16 +98,24 @@ slice <- function(max_doublings = 5) {
 #' (positive, will be tuned).
 #' @param max_leapfrog_steps numeric. Default 1000. Maximum number of leapfrog
 #' steps used. The algorithm will determine the optimal number less than this.
+#' @param method character of length one. Currently this can only be "SNAPER",
+#' but in future it may expand to other adaptive samplers.
+#' @details For `adaptive_hmc()`, the Lmin and Lmax parameters are learned
+#' during sampling, so they are not provided as arguments. The number of
+#' chains cannot be less than 2, because SNAPER adapts using statistics
+#' pooled across chains. `diag_sd` is used to rescale the parameter space to
 #' make it more uniform, and make sampling more efficient.
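+#' @examples
+#' \dontrun{
+#' # hypothetical usage sketch, assuming the interface above; SNAPER
+#' # adapts using statistics pooled across chains, so use >= 2 chains
+#' m <- model(normal(0, 1))
+#' draws <- mcmc(m, sampler = adaptive_hmc(), chains = 4)
+#' }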
-snaper_hmc <- function( +adaptive_hmc <- function( max_leapfrog_steps = 1000, epsilon = 0.1, - diag_sd = 1 + diag_sd = 1, + method = "SNAPER" ) { + + method <- rlang::arg_match( + arg = method, + values = "SNAPER" + ) + # nolint end obj <- list( parameters = list( @@ -115,9 +123,9 @@ snaper_hmc <- function( epsilon = epsilon, diag_sd = diag_sd ), - class = snaper_hmc_sampler + class = adaptive_hmc_sampler ) - class(obj) <- c("snaper hmc sampler", "sampler") + class(obj) <- c("adaptive_hmc_sampler", "sampler") obj } From 10f76e3c423a1d71ac2cf454af531bceb80d4586 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 13 Feb 2025 16:08:21 +0800 Subject: [PATCH 04/11] NG notes on appropriately handling the kernel --- R/inference_class.R | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/R/inference_class.R b/R/inference_class.R index 42e8fbd7..6614d040 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -307,6 +307,15 @@ sampler <- R6Class( }, define_tf_evaluate_sample_batch = function(){ + + # create a dummy sample_param_vec (vector with length as defined below) + # dummy_sampler_param_vec <- self$sampler_parameter_values() + # create dummy kernel using this, with: + # dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec) + # use dummy kernel to bootrap a dummy results object + # dummy_kernel_results <- dummy_kernel$bootstrap_results() + # use dummy results object to make a tensorspec or whatever + self$tf_evaluate_sample_batch <- tensorflow::tf_function( f = self$define_tf_draws, input_signature = list( @@ -327,6 +336,8 @@ sampler <- R6Class( ) ) ), + # kernel_results + kernel$bootstrap_results() dtype = tf_float() ) ) @@ -744,7 +755,8 @@ sampler <- R6Class( free_state = self$free_state, sampler_burst_length = as.integer(n_samples), sampler_thin = as.integer(thin), - sampler_param_vec = param_vec + sampler_param_vec = param_vec, + kernel_results = kernel_results ) # get trace of free state and drop the null dimension @@ -789,7 +801,8 @@ sampler <- R6Class( sample_carefully = function(free_state, sampler_burst_length, sampler_thin, - sampler_param_vec) { + sampler_param_vec, + kernel_results) { # tryCatch handling for numerical errors dag <- self$model$dag @@ -799,6 +812,10 @@ sampler <- R6Class( # ADPATIVE HMC # TODO - this is where the adaptive_hmc fails at the moment + + # so we can pass in the results from the previous kernel + dummy_kernel <- self$define_tf_kernel() + result <- cleanly( self$tf_evaluate_sample_batch( free_state = tensorflow::as_tensor( @@ -811,7 +828,8 @@ sampler <- R6Class( sampler_param_vec, dtype = tf_float(), shape = length(sampler_param_vec) - ) + ), + kernel_results = kernel_results ) ) # closing cleanly @@ -1120,8 +1138,8 @@ adaptive_hmc_sampler <- R6Class( # return named list for replacing tensors list( adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps, - adaptive_hmc_epsilon = epsilon, - adaptive_hmc_diag_sd = diag_sd, + # adaptive_hmc_epsilon = epsilon, + # adaptive_hmc_diag_sd = diag_sd, method = method ) } From a0672adddd11a17375295ae49293a7d419994773 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 13 Feb 2025 16:54:51 +0800 Subject: [PATCH 05/11] trying to pass through the appropriate TensorSpec --- R/inference_class.R | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/R/inference_class.R b/R/inference_class.R index 6614d040..97e94839 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -307,13 +307,20 @@ sampler <- R6Class( }, 
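The idea this patch is prototyping: `tensorflow::tf_function()` needs a `TensorSpec` (or a nested structure of them) for each argument, and a kernel results object is exactly such a nested structure, so a dummy kernel is built and bootstrapped purely to discover that structure. A standalone sketch of the mechanism, using a plain HMC kernel and an assumed `log_prob_fn` for illustration:

# build any kernel and bootstrap its results from a dummy state, just to
# learn the nested structure of the kernel results
kernel <- tfp$mcmc$HamiltonianMonteCarlo(
  target_log_prob_fn = log_prob_fn,  # assumed: the model's log density
  step_size = 0.1,
  num_leapfrog_steps = 10L
)
dummy_state <- tf$zeros(shape(4L, 3L), dtype = tf$float64)  # 4 chains, 3 params
dummy_results <- kernel$bootstrap_results(dummy_state)

# map every tensor leaf in that structure to a matching TensorSpec;
# non-tensor leaves pass through unchanged
results_spec <- tf$nest$map_structure(
  function(x) {
    if (tf$is_tensor(x)) tf$TensorSpec(shape = x$shape, dtype = x$dtype) else x
  },
  dummy_results
)
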
define_tf_evaluate_sample_batch = function(){ + browser() + + dummy_init_state <- matrix(data = 0, + nrow = nrow(self$free_state), + ncol = ncol(self$free_state)) # create a dummy sample_param_vec (vector with length as defined below) - # dummy_sampler_param_vec <- self$sampler_parameter_values() + dummy_sampler_param_vec <- length(unlist(self$sampler_parameter_values())) # create dummy kernel using this, with: - # dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec) + dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec) # use dummy kernel to bootrap a dummy results object - # dummy_kernel_results <- dummy_kernel$bootstrap_results() + dummy_kernel_results <- dummy_kernel$bootstrap_results( + init_state = dummy_init_state + ) # use dummy results object to make a tensorspec or whatever self$tf_evaluate_sample_batch <- tensorflow::tf_function( @@ -335,10 +342,13 @@ sampler <- R6Class( self$sampler_parameter_values() ) ) - ), + ) + ), # kernel_results - kernel$bootstrap_results() - dtype = tf_float() + tf$TensorSpec( + shape = list( + length(dummy_kernel_results) + ) ) ) ) From 435d9aa736a8a80e53255cf8fd05f944dcd5a9d6 Mon Sep 17 00:00:00 2001 From: njtierney Date: Fri, 14 Feb 2025 10:37:52 +0800 Subject: [PATCH 06/11] try creating a tensor spec for the kernel results using tf$nest$map_structure --- R/inference_class.R | 12 +++++++----- R/utils.R | 13 +++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/R/inference_class.R b/R/inference_class.R index 97e94839..a2a0adda 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -321,6 +321,12 @@ sampler <- R6Class( dummy_kernel_results <- dummy_kernel$bootstrap_results( init_state = dummy_init_state ) + + dummy_kernel_results_tensor_spec <- tf$nest$map_structure( + maybe_make_tensor_shape, + dummy_kernel_results + ) + # use dummy results object to make a tensorspec or whatever self$tf_evaluate_sample_batch <- tensorflow::tf_function( @@ -345,11 +351,7 @@ sampler <- R6Class( ) ), # kernel_results - tf$TensorSpec( - shape = list( - length(dummy_kernel_results) - ) - ) + dummy_kernel_results_tensor_spec ) ) }, diff --git a/R/utils.R b/R/utils.R index 1abe0105..d42ab725 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1003,3 +1003,16 @@ n_warmup <- function(x){ x_info <- attr(x, "model_info") x_info$warmup } + +build_tensor_spec <- function(tensor){ + tf$TensorSpec(shape = tensor$shape, + dtype = tensor$dtype) +} + +maybe_make_tensor_shape <- function(x){ + if (tf$is_tensor(x)) { + build_tensor_spec(x) + } else{ + x + } +} From 1c7a6f3ea0d0fb1a010702c3050e3de312d224ac Mon Sep 17 00:00:00 2001 From: njtierney Date: Fri, 14 Feb 2025 12:02:22 +0800 Subject: [PATCH 07/11] notes from meeting with NG --- R/inference_class.R | 177 ++++++++++++++++++++++++++------------------ 1 file changed, 103 insertions(+), 74 deletions(-) diff --git a/R/inference_class.R b/R/inference_class.R index a2a0adda..9e481ed9 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -275,6 +275,9 @@ sampler <- R6Class( # batch sizes for tracing trace_batch_size = 100, + + kernel_results = NULL, + initialize = function(initial_values, model, parameters = list(), @@ -306,52 +309,48 @@ sampler <- R6Class( }, + # TODO two versions of `define_tf_evaluate_sample_batch()` + # one that does warmup + # this returns the last trace + # and the kernel + # one that does sampling + # then we need to curry this with the kernel used in warmup + + # sampling one? 
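A sketch of that two-function plan, using TFP directly (object names illustrative): run warmup once, keeping the final kernel results, then resume sampling from both the last state and the tuned kernel state:

# stage 1: warmup, keeping the adapted kernel state
warm <- tfp$mcmc$sample_chain(
  num_results = n_warmup,
  current_state = init_state,
  kernel = kernel,
  return_final_kernel_results = TRUE,
  trace_fn = function(state, kernel_results) kernel_results
)

# stage 2: sampling, resuming from where warmup finished; the sampling
# function is curried with the tuned results rather than re-adapting
samples <- tfp$mcmc$sample_chain(
  num_results = n_samples,
  current_state = get_last_state(warm$all_states),  # helper added later in this PR
  kernel = kernel,
  previous_kernel_results = warm$final_kernel_results,
  trace_fn = function(state, kernel_results) kernel_results
)
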
define_tf_evaluate_sample_batch = function(){ browser() - dummy_init_state <- matrix(data = 0, - nrow = nrow(self$free_state), - ncol = ncol(self$free_state)) - # create a dummy sample_param_vec (vector with length as defined below) - dummy_sampler_param_vec <- length(unlist(self$sampler_parameter_values())) - # create dummy kernel using this, with: - dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec) - # use dummy kernel to bootrap a dummy results object - dummy_kernel_results <- dummy_kernel$bootstrap_results( - init_state = dummy_init_state - ) - - dummy_kernel_results_tensor_spec <- tf$nest$map_structure( - maybe_make_tensor_shape, - dummy_kernel_results - ) - - # use dummy results object to make a tensorspec or whatever + sampler_param_length <- length(unlist(self$sampler_parameter_values())) self$tf_evaluate_sample_batch <- tensorflow::tf_function( f = self$define_tf_draws, input_signature = list( # free state - tf$TensorSpec(shape = list(NULL, self$n_free), - dtype = tf_float()), + tf$TensorSpec( + # TODO + # we might be able to remove this NULL + # with a known shape, might make TF happier + # this will just be the number of chains + shape = list(NULL, self$n_free), + dtype = tf_float() + ), # sampler_burst_length - tf$TensorSpec(shape = list(), - dtype = tf$int32), + tf$TensorSpec( + shape = list(), + dtype = tf$int32 + ), # sampler_thin - tf$TensorSpec(shape = list(), - dtype = tf$int32), + tf$TensorSpec( + shape = list(), + dtype = tf$int32 + ), # sampler_param_vec - tf$TensorSpec(shape = list( - length( - unlist( - self$sampler_parameter_values() - ) + tf$TensorSpec( + shape = list(sampler_param_length), + # TODO this is new + dtype = tf$float64 ) - ) - ), - # kernel_results - dummy_kernel_results_tensor_spec ) ) }, @@ -397,8 +396,17 @@ sampler <- R6Class( # how big would we like the bursts to be ideal_burst_size <- ifelse(one_by_one, 1L, pb_update) + # TODO # if warmup is required, do that now if (warmup > 0) { + + # TODO + # create kernel results dummy object + # assign to self$kernel_results + # rebuild tf_evaluate_sample_batch function + # self$define_tf_evaluate_sample_batch() + # also get rid of bursts for warmup / progress bar goes from 0->100 + if (verbose) { pb_warmup <- create_progress_bar( "warmup", @@ -412,57 +420,73 @@ sampler <- R6Class( pb_warmup <- NULL } - # split up warmup iterations into bursts of sampling - burst_lengths <- self$burst_lengths(warmup, - ideal_burst_size, - warmup = TRUE - ) - completed_iterations <- cumsum(burst_lengths) + self$run_burst(n_samples = warmup) + # # split up warmup iterations into bursts of sampling + # burst_lengths <- self$burst_lengths(warmup, + # ideal_burst_size, + # warmup = TRUE + # ) + # completed_iterations <- cumsum(burst_lengths) # relay between R and tensorflow in a burst to be cpu efficient - for (burst in seq_along(burst_lengths)) { - # TF1/2 check todo? - # replace with define_tf_draws - - self$run_burst(n_samples = burst_lengths[burst]) - # align the free state back to the parameters we are tracing - # TF1/2 check todo? - # this is the tuning stage, might not need to evaluate - # / record the parameter values, as they will be thrown away - # after warmup - so could remove trace here. - - self$trace() - # a memory efficient way to calculate summary stats of samples - self$update_welford() - self$tune(completed_iterations[burst], warmup) - - if (verbose) { + # for (burst in seq_along(burst_lengths)) { + # # TF1/2 check todo? 
+ # # replace with define_tf_draws + # + # # self$run_burst(n_samples = burst_lengths[burst]) + # # align the free state back to the parameters we are tracing + # # TF1/2 check todo? + # # this is the tuning stage, might not need to evaluate + # # / record the parameter values, as they will be thrown away + # # after warmup - so could remove trace here. + # + # self$trace() + # # a memory efficient way to calculate summary stats of samples + # self$update_welford() + # self$tune(completed_iterations[burst], warmup) + # + # if (verbose) { + # + # # update the progress bar/percentage log + # iterate_progress_bar(pb_warmup, + # it = completed_iterations[burst], + # rejects = self$numerical_rejections, + # chains = self$n_chains, + # file = self$pb_file + # ) + # + # self$write_percentage_log(warmup, + # completed_iterations[burst], + # stage = "warmup" + # ) + # } + # } + + ## close progress bar - # update the progress bar/percentage log iterate_progress_bar(pb_warmup, - it = completed_iterations[burst], - rejects = self$numerical_rejections, + it = warmup, + # TODO + # grab the rejected samples somehow + # rejects = self$numerical_rejections, + rejects = 1L, chains = self$n_chains, file = self$pb_file ) - self$write_percentage_log(warmup, - completed_iterations[burst], - stage = "warmup" - ) - } - } - # scrub the free state trace and numerical rejections - self$traced_free_state <- replicate(self$n_chains, - matrix(NA, 0, self$n_free), - simplify = FALSE - ) + # self$traced_free_state <- replicate(self$n_chains, + # matrix(NA, 0, self$n_free), + # simplify = FALSE + # ) self$numerical_rejections <- 0 } - + # TODO possibly grab out last state (final free state), put into sampling if (n_samples > 0) { + # Recompile kernel results - self$define_tf_evaluate_sample_batch() + # then...the rest of this code should work + # on exiting during the main sampling period (even if killed by the # user) trace the free state values @@ -721,8 +745,10 @@ sampler <- R6Class( browser() sampler_batch <- tfp$mcmc$sample_chain( num_results = tf$math$floordiv(sampler_burst_length, sampler_thin), + previous_kernel_results = self$kernel_results, current_state = free_state, kernel = sampler_kernel, + return_final_kernel_results = TRUE, trace_fn = function(current_state, kernel_results) { kernel_results }, @@ -745,6 +771,7 @@ sampler <- R6Class( dag <- self$model$dag tfe <- dag$tf_environment + # TODO we can use the param vec internally inside TF/TFP param_vec <- unlist(self$sampler_parameter_values()) # combine the sampler information with information on the sampler's tuning # parameters, and make into a dict @@ -760,6 +787,7 @@ sampler <- R6Class( # # and then run the code for the sampler_batch # + # TODO - possibly TF1 thing, can remove dag$set_tf_data_list("batch_size", nrow(self$free_state)) # run the sampler, handling numerical errors @@ -771,6 +799,8 @@ sampler <- R6Class( kernel_results = kernel_results ) + self$kernel_results <- batch_results$final_kernel_results + # get trace of free state and drop the null dimension if (is.null(batch_results$all_states)){ browser() @@ -838,10 +868,9 @@ sampler <- R6Class( sampler_thin = tensorflow::as_tensor(sampler_thin), sampler_param_vec = tensorflow::as_tensor( sampler_param_vec, - dtype = tf_float(), + dtype = tf$float64, shape = length(sampler_param_vec) - ), - kernel_results = kernel_results + ) ) ) # closing cleanly From 15d488898e8b7a41c1ecd3b58177ed2cd383d5eb Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 6 Mar 2025 12:18:09 +1100 Subject: [PATCH 08/11] add 
some missing commas --- R/inference_class.R | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/R/inference_class.R b/R/inference_class.R index 9e481ed9..0512fe59 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -309,6 +309,12 @@ sampler <- R6Class( }, + define_tf_evaluate_sample_batch_warmup = function(){ + # This does warmup, returns: + # Last trace + # Kernel + }, + # TODO two versions of `define_tf_evaluate_sample_batch()` # one that does warmup # this returns the last trace @@ -355,6 +361,23 @@ sampler <- R6Class( ) }, + create_dummy_kernel_results = function(){ + + dummy_init_state <- matrix(data = 0, + nrow = nrow(self$free_state), + ncol = ncol(self$free_state)) + + dummy_sampler_param_vec <- length(unlist(self$sampler_parameter_values())) + # create dummy kernel using this, with: + dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec) + # use dummy kernel to bootrap a dummy results object + dummy_kernel_results <- dummy_kernel$bootstrap_results( + init_state = dummy_init_state + ) + + dummy_kernel_results + }, + run_chain = function(n_samples, thin, warmup, verbose, pb_update, one_by_one, plan_is, n_cores, float_type, @@ -403,8 +426,9 @@ sampler <- R6Class( # TODO # create kernel results dummy object # assign to self$kernel_results + self$kernel_results <- self$create_dummy_kernel_results() # rebuild tf_evaluate_sample_batch function - # self$define_tf_evaluate_sample_batch() + self$define_tf_evaluate_sample_batch() # also get rid of bursts for warmup / progress bar goes from 0->100 if (verbose) { From 8fd433421c037e2faf13fff8516f8e932f381be3 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 6 Mar 2025 12:26:43 +1100 Subject: [PATCH 09/11] make sure tests for adaptive explicitly call sampler = adaptive_hmc() --- tests/testthat/test-adaptive-hmc.R | 50 ++++++++++++++++++------------ 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/tests/testthat/test-adaptive-hmc.R b/tests/testthat/test-adaptive-hmc.R index 450c2fe7..45874fa6 100644 --- a/tests/testthat/test-adaptive-hmc.R +++ b/tests/testthat/test-adaptive-hmc.R @@ -15,14 +15,17 @@ test_that("bad mcmc proposals are rejected", { ) expect_match(out, "100% bad") - expect_snapshot(error = TRUE, - draws <- mcmc(m, - chains = 1, - n_samples = 2, - warmup = 0, - verbose = FALSE, - initial_values = initials(z = 1e120) - ) + expect_snapshot( + error = TRUE, + draws <- mcmc( + m, + chains = 1, + n_samples = 2, + warmup = 0, + verbose = FALSE, + sampler = adaptive_hmc(), + initial_values = initials(z = 1e120) + ) ) # really bad proposals @@ -30,22 +33,29 @@ test_that("bad mcmc proposals are rejected", { z <- normal(-1e120, 1e-120) distribution(x) <- normal(z, 1e-120) m <- model(z, precision = "single") - expect_snapshot(error = TRUE, - mcmc(m, chains = 1, n_samples = 1, warmup = 0, verbose = FALSE) + expect_snapshot( + error = TRUE, + mcmc(m, + chains = 1, + n_samples = 1, + warmup = 0, + sampler = adaptive_hmc(), + verbose = FALSE) ) # proposals that are fine, but rejected anyway z <- normal(0, 1) m <- model(z, precision = "single") - expect_ok(mcmc(m, - adaptive_hmc( - epsilon = 100, - # Lmin = 1, - # Lmax = 1 - ), - chains = 1, - n_samples = 5, - warmup = 0, - verbose = FALSE + expect_ok(mcmc( + m, + adaptive_hmc( + epsilon = 100, + # Lmin = 1, + # Lmax = 1 + ), + chains = 1, + n_samples = 5, + warmup = 0, + verbose = FALSE )) }) From 1d5c41583b4c460e317a70e7f8e43c61f55b859b Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 6 Mar 2025 12:55:49 +1100 Subject: 
[PATCH 10/11] Add some helper functions for working with parts of snaper / adaptive hmc

---
 R/inference_class.R |  2 ++
 R/utils.R           | 71 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+)

diff --git a/R/inference_class.R b/R/inference_class.R
index 0512fe59..7e6c704d 100644
--- a/R/inference_class.R
+++ b/R/inference_class.R
@@ -309,10 +309,12 @@ sampler <- R6Class(
   },
 
+  ## TODO
   define_tf_evaluate_sample_batch_warmup = function(){
     # This does warmup, returns:
     # Last trace
     # Kernel
+
   },
 
diff --git a/R/utils.R b/R/utils.R
index d42ab725..2c0a3525 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -1016,3 +1016,74 @@ maybe_make_tensor_shape <- function(x){
     x
   }
 }
+
+# get the final model parameter state from a chain, as returned in the
+# all_states object from tfp$mcmc$sample_chain
+get_last_state <- function(all_states) {
+  n_iter <- dim(all_states)[1]
+  tf$gather(all_states, n_iter - 1L, 0L)
+}
+
+# find out if MCMC steps had non-finite acceptance probabilities
+bad_steps <- function(kernel_results) {
+  log_accept_ratios <- recursive_get_log_accept_ratio(kernel_results)
+  !is.finite(log_accept_ratios)
+}
+
+# recursively extract the log acceptance ratio from the MCMC kernel
+recursive_get_log_accept_ratio <- function(kernel_results) {
+  nm <- names(kernel_results)
+  if ("log_accept_ratio" %in% nm) {
+    log_accept_ratios <- kernel_results$log_accept_ratio
+  } else if ("inner_results" %in% nm) {
+    log_accept_ratios <- recursive_get_log_accept_ratio(
+      kernel_results$inner_results
+    )
+  } else {
+    stop("non-standard kernel structure")
+  }
+  as.array(log_accept_ratios)
+}
+
+# given a warmed-up sampler object, return a compiled TF function that
+# generates a new burst of samples from it
+make_sampler_function <- function(warm_sampler) {
+
+  # make the uncompiled function (with curried arguments)
+  sample_raw <- function(current_state, n_samples) {
+    results <- tfp$mcmc$sample_chain(
+      # how many iterations
+      num_results = n_samples,
+      # where to start from
+      current_state = current_state,
+      # kernel
+      kernel = warm_sampler$kernel,
+      # tuned sampler settings
+      previous_kernel_results = warm_sampler$kernel_results,
+      # what to trace (the kernel results, for diagnostics)
+      trace_fn = function(current_state, kernel_results) {
+        # could compute badness here to save memory?
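+        # (tracing a single badness flag per step, instead of the full
+        #  kernel results structure, would shrink the traced object; the
+        #  hard-coded nesting for this kernel stack is shown on the next
+        #  line, mirroring recursive_get_log_accept_ratio() above)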
+ # is.finite(kernel_results$inner_results$inner_results$inner_results$log_accept_ratio) + kernel_results + } + ) + # return the parameter states and the kernel results + list( + all_states = results$all_states, + kernel_results = results$trace + ) + } + + # compile it into a concrete function and return + sample <- tf_function(sample_raw, + list( + as_tensorspec(warm_sampler$current_state), + tf$TensorSpec(shape = c(), + dtype = tf$int32) + )) + + sample + +} From 0ddff31a219f65466b381aeb4d7b17e3fc82afb3 Mon Sep 17 00:00:00 2001 From: njtierney Date: Tue, 11 Mar 2025 16:26:41 +1100 Subject: [PATCH 11/11] add R6 adaptive hmc sampler class --- R/samplers.R | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/R/samplers.R b/R/samplers.R index b1591e82..f5771953 100644 --- a/R/samplers.R +++ b/R/samplers.R @@ -309,3 +309,69 @@ slice_sampler <- R6Class( } ) ) + +adaptive_hmc_sampler <- R6Class( + "adaptive_hmc_sampler", + inherit = sampler, + public = list( + parameters = list( + # Lmin = 10, + # Lmax = 20, + max_leapfrog_steps = 1000, + # TODO clean up these parameter usage else where + # epsilon = 0.005, + # diag_sd = 1, + # TODO some kind of validity check of method? Currently this can only be + # "SNAPER". + method = "SNAPER" + ), + accept_target = 0.651, + + define_tf_kernel = function(sampler_param_vec) { + dag <- self$model$dag + tfe <- dag$tf_environment + + free_state_size <- length(sampler_param_vec) - 2 + + adaptive_hmc_max_leapfrog_steps <- tf$cast( + x = sampler_param_vec[0], + dtype = tf$int32 + ) + # TODO pipe that in properly + n_warmup <- sampler_param_vec[1] + # adaptive_hmc_epsilon <- sampler_param_vec[1] + # adaptive_hmc_diag_sd <- sampler_param_vec[2:(1+free_state_size)] + + kernel_base <- tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo( + target_log_prob_fn = dag$tf_log_prob_function_adjusted, + step_size = 1, + num_adaptation_steps = as.integer(self$warmup), + max_leapfrog_steps = adaptive_hmc_max_leapfrog_steps + ) + + sampler_kernel <- tfp$mcmc$DualAveragingStepSizeAdaptation( + inner_kernel = kernel_base, + num_adaptation_steps = as.integer(self$warmup) + ) + + return( + sampler_kernel + ) + }, + sampler_parameter_values = function() { + # random number of integration steps + max_leapfrog_steps <- self$parameters$max_leapfrog_steps + epsilon <- self$parameters$epsilon + diag_sd <- matrix(self$parameters$diag_sd) + method <- self$parameters$method + + # return named list for replacing tensors + list( + adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps, + # adaptive_hmc_epsilon = epsilon, + # adaptive_hmc_diag_sd = diag_sd, + method = method + ) + } + ) +)
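
Taken together, a sketch of how the patch-10 helpers are intended to compose once the warmup plumbing lands. The `warm_sampler` object here is hypothetical (a list carrying `kernel`, `kernel_results`, and `current_state` fields, as `make_sampler_function()` assumes), and note that `as_tensorspec()` is not yet defined anywhere in this PR; `build_tensor_spec()` from patch 06 is the closest existing helper:

# compile a burst-sampling function from a (hypothetical) warmed-up sampler
sample_fn <- make_sampler_function(warm_sampler)

# draw a burst of 100 samples, starting from the warmed-up state
burst <- sample_fn(warm_sampler$current_state, 100L)

# pieces the outer sampling loop needs:
new_state <- get_last_state(burst$all_states)  # restart point for the next burst
n_bad <- sum(bad_steps(burst$kernel_results))  # numerical rejections to report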