diff --git a/R/samplers.R b/R/samplers.R index 629808fd..99a70974 100644 --- a/R/samplers.R +++ b/R/samplers.R @@ -94,6 +94,45 @@ slice <- function(max_doublings = 5) { obj } +#' @rdname samplers +#' @export +#' +#' @param epsilon leapfrog stepsize hyperparameter (positive, will be tuned) +#' @param diag_sd estimate of the posterior marginal standard deviations +#' (positive, will be tuned). +#' @param max_leapfrog_steps numeric. Default 1000. Maximum number of leapfrog +#' steps used. The algorithm will determine the optimal number less than this. +#' @param method character length one. Currently can only be "SNAPER" but in +#' the future this may expand to other adaptive samplers. +#' @details For `adaptive_hmc()`. The Lmin and Lmax parameters are learnt and so +#' not provided in this. The number of chains cannot be less than 2, due to +#' how adaptive HMC works. `diag_sd` is used to rescale the parameter space to +#' make it more uniform, and make sampling more efficient. +adaptive_hmc <- function( + max_leapfrog_steps = 1000, + epsilon = 0.1, + diag_sd = 1, + method = "SNAPER" +) { + method <- rlang::arg_match( + arg = method, + values = "SNAPER" + ) + + # nolint end + obj <- list( + parameters = list( + max_leapfrog_steps = max_leapfrog_steps, + epsilon = epsilon, + diag_sd = diag_sd + ), + class = adaptive_hmc_sampler + ) + class(obj) <- c("adaptive_hmc_sampler", "sampler") + obj +} + + #' @noRd #' @export print.sampler <- function(x, ...) { @@ -120,7 +159,6 @@ print.sampler <- function(x, ...) { cat(msg) } - hmc_sampler <- R6Class( "hmc_sampler", inherit = sampler, @@ -331,10 +369,11 @@ adaptive_hmc_sampler <- R6Class( dag <- self$model$dag tfe <- dag$tf_environment - free_state_size <- length(sampler_param_vec) - 2 + # free_state_size <- length(sampler_param_vec) - 2 + free_state_size <- length(sampler_param_vec) adaptive_hmc_max_leapfrog_steps <- tf$cast( - x = sampler_param_vec[0], + x = sampler_param_vec[1], dtype = tf$int32 ) # TODO pipe that in properly @@ -360,21 +399,19 @@ adaptive_hmc_sampler <- R6Class( ) }, - # given MCMC kernel `kernel` and initial model parameter state `init`, adapt - # the kernel tuning parameters whilst simultaneously burning-in the model - # parameter state. Return both finalised kernel tuning parameters and the - # burned-in model parameter state - warm_up_sampler = function(kernel, init) { - + # given MCMC kernel `sampler_kernel` and initial model parameter state + # `free_state`, adapt the kernel tuning parameters whilst simultaneously + # burning-in the model parameter state. Return both finalised kernel + # tuning parameters and the burned-in model parameter state + warm_up_sampler = function(sampler_kernel, n_adapt, free_state) { # get the predetermined adaptation period of the kernel - n_adapt <- kernel$num_adaptation_steps # make the uncompiled function (with curried arguments) warmup_raw <- function() { tfp$mcmc$sample_chain( num_results = n_adapt, - current_state = init, - kernel = kernel, + current_state = free_state, + kernel = sampler_kernel, return_final_kernel_results = TRUE, trace_fn = function(current_state, kernel_results) { kernel_results$step #kernel_results @@ -383,7 +420,7 @@ adaptive_hmc_sampler <- R6Class( } # compile it into a concrete function - warmup <- tf_function(warmup_raw) + warmup <- tensorflow::tf_function(warmup_raw) # execute it result <- warmup() @@ -391,12 +428,11 @@ adaptive_hmc_sampler <- R6Class( # return the last (burned-in) state of the model parameters and the final # (tuned) kernel parameters list( - kernel = kernel, + kernel = sampler_kernel, kernel_results = result$final_kernel_results, current_state = get_last_state(result$all_states) ) - - } + }, sampler_parameter_values = function() { # random number of integration steps @@ -414,159 +450,168 @@ adaptive_hmc_sampler <- R6Class( ) }, - # given a warmed up sampler object, return a compiled TF function - # that generates a new burst of samples from samples from it - make_sampler_function = function(warm_sampler) { - - # make the uncompiled function (with curried arguments) - sample_raw <- function(current_state, n_samples) { - results <- tfp$mcmc$sample_chain( - # how many iterations - num_results = n_samples, - # where to start from - current_state = current_state, - # kernel - kernel = warm_sampler$kernel, - # tuned sampler settings - previous_kernel_results = warm_sampler$kernel_results, - # what to trace (nothing) - trace_fn = function(current_state, kernel_results) { - # could compute badness here to save memory? - # is.finite(kernel_results$inner_results$inner_results$inner_results$log_accept_ratio) - kernel_results - } - ) - # return the parameter states and the kernel results - list( - all_states = results$all_states, - kernel_results = results$trace - ) - } - - # compile it into a concrete function and return - sample <- tf_function(sample_raw, - list( - as_tensorspec(warm_sampler$current_state), - tf$TensorSpec(shape = c(), - dtype = tf$int32) - )) - - sample - - }, - - run_warmup = function( - n_samples, - pb_update, - ideal_burst_size, - verbose - ) { - perform_warmup <- self$warmup > 0 - if (perform_warmup) { - # adapt and warm up - # self$kernel? - # self$init? - result <- self$warm_up_sampler(kernel, init) - } - - result - }, - - run_sampling = function( - n_samples, - pb_update, - ideal_burst_size, - trace_batch_size, - thin, - verbose - ) { - perform_sampling <- n_samples > 0 - if (perform_sampling) { - # on exiting during the main sampling period (even if killed by the - # user) trace the free state values - - on.exit(self$trace_values(trace_batch_size), add = TRUE) - - # main sampling - if (verbose) { - pb_sampling <- create_progress_bar( - phase = "sampling", - iter = c(self$warmup, n_samples), - pb_update = pb_update, - width = self$pb_width + # given a warmed up sampler object, return a compiled TF function + # that generates a new burst of samples from samples from it + make_sampler_function = function(warm_sampler) { + # make the uncompiled function (with curried arguments) + sample_raw <- function(current_state, n_samples) { + results <- tfp$mcmc$sample_chain( + # how many iterations + num_results = n_samples, + # where to start from + current_state = current_state, + # kernel + kernel = warm_sampler$kernel, + # tuned sampler settings + previous_kernel_results = warm_sampler$kernel_results, + # what to trace (nothing) + trace_fn = function(current_state, kernel_results) { + # could compute badness here to save memory? + # is.finite(kernel_results$inner_results$inner_results$inner_results$log_accept_ratio) + kernel_results + } ) - iterate_progress_bar( - pb = pb_sampling, - it = 0, - rejects = 0, - chains = self$n_chains, - file = self$pb_file + # return the parameter states and the kernel results + list( + all_states = results$all_states, + kernel_results = results$trace ) - } else { - pb_sampling <- NULL } - ### Adaptive start - print("Sampling parameters") - for (burst in seq_len(n_bursts)) { - burst_result <- sample( - current_state = current_state, - n_samples = burst_size + # compile it into a concrete function and return + sample <- tensorflow::tf_function( + sample_raw, + list( + as_tensorspec(warm_sampler$current_state), + tf$TensorSpec(shape = c(), dtype = tf$int32) ) + ) - # trace the MCMC results from this burst - burst_idx <- (burst - 1) * burst_size + seq_len(burst_size) - trace[burst_idx, , ] <- as.array(burst_result$all_states) - - # overwrite the current state - current_state <- get_last_state(burst_result$all_states) - - # accumulate and report on the badness - new_badness <- sum(bad_steps(burst_result$kernel_results)) - n_bad <- n_bad + new_badness - n_evaluations <- burst * burst_size * n_chains - perc_badness <- round(100 * n_bad / n_evaluations) - - # report on progress - print(sprintf("burst %i of %i (%i%s bad)", - burst, - n_bursts, - perc_badness, - "%")) + sample + }, + run_warmup = function( + n_samples, + pb_update, + ideal_burst_size, + verbose + ) { + perform_warmup <- self$warmup > 0 + if (perform_warmup) { + # adapt and warm up + param_vec <- unlist(self$sampler_parameter_values()) + sampler_kernel <- self$define_tf_kernel( + sampler_param_vec = param_vec + ) + init <- self$free_state + n_adapt <- as.integer(param_vec) + result <- self$warm_up_sampler( + sampler_kernel = sampler_kernel, + n_adapt = n_adapt, + free_state = init + ) } - ### Adaptive end - # split up warmup iterations into bursts of sampling - burst_lengths <- self$burst_lengths(n_samples, ideal_burst_size) - completed_iterations <- cumsum(burst_lengths) - - for (burst in seq_along(burst_lengths)) { - # so these bursts are R objects being passed through to python - # and how often to return them - # TF1/2 check todo - # replace with define_tf_draws - self$run_burst(n_samples = burst_lengths[burst], thin = thin) - # trace is it receiving the python - self$trace() + result + }, + run_sampling = function( + n_samples, + pb_update, + ideal_burst_size, + trace_batch_size, + thin, + verbose + ) { + perform_sampling <- n_samples > 0 + if (perform_sampling) { + # on exiting during the main sampling period (even if killed by the + # user) trace the free state values + + on.exit(self$trace_values(trace_batch_size), add = TRUE) + + # main sampling if (verbose) { - # update the progress bar/percentage log + pb_sampling <- create_progress_bar( + phase = "sampling", + iter = c(self$warmup, n_samples), + pb_update = pb_update, + width = self$pb_width + ) iterate_progress_bar( pb = pb_sampling, - it = completed_iterations[burst], - rejects = self$numerical_rejections, + it = 0, + rejects = 0, chains = self$n_chains, file = self$pb_file ) + } else { + pb_sampling <- NULL + } - self$write_percentage_log( - total = n_samples, - completed = completed_iterations[burst], - stage = "sampling" + ### Adaptive start + print("Sampling parameters") + for (burst in seq_len(n_bursts)) { + burst_result <- sample( + current_state = current_state, + n_samples = burst_size ) + + # trace the MCMC results from this burst + burst_idx <- (burst - 1) * burst_size + seq_len(burst_size) + trace[burst_idx, , ] <- as.array(burst_result$all_states) + + # overwrite the current state + current_state <- get_last_state(burst_result$all_states) + + # accumulate and report on the badness + new_badness <- sum(bad_steps(burst_result$kernel_results)) + n_bad <- n_bad + new_badness + n_evaluations <- burst * burst_size * n_chains + perc_badness <- round(100 * n_bad / n_evaluations) + + # report on progress + print(sprintf( + "burst %i of %i (%i%s bad)", + burst, + n_bursts, + perc_badness, + "%" + )) } - } - } # end sampling - }, + ### Adaptive end + + # split up warmup iterations into bursts of sampling + burst_lengths <- self$burst_lengths(n_samples, ideal_burst_size) + completed_iterations <- cumsum(burst_lengths) + + for (burst in seq_along(burst_lengths)) { + # so these bursts are R objects being passed through to python + # and how often to return them + # TF1/2 check todo + # replace with define_tf_draws + self$run_burst(n_samples = burst_lengths[burst], thin = thin) + # trace is it receiving the python + self$trace() + + if (verbose) { + # update the progress bar/percentage log + iterate_progress_bar( + pb = pb_sampling, + it = completed_iterations[burst], + rejects = self$numerical_rejections, + chains = self$n_chains, + file = self$pb_file + ) + + self$write_percentage_log( + total = n_samples, + completed = completed_iterations[burst], + stage = "sampling" + ) + } + } + } # end sampling + } + ) )