Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adaptive hmc v2 i765 #778

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions R/samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,69 @@ slice_sampler <- R6Class(
}
)
)

# Sampler class for adaptive (SNAPER) Hamiltonian Monte Carlo.
# Builds a TFP SNAPER-HMC kernel wrapped in dual-averaging step size
# adaptation, both adapting over the warmup period.
adaptive_hmc_sampler <- R6Class(
  "adaptive_hmc_sampler",
  inherit = sampler,
  public = list(
    # user-facing tuning parameters for this sampler
    parameters = list(
      # Lmin = 10,
      # Lmax = 20,
      max_leapfrog_steps = 1000,
      # TODO clean up these parameter usage else where
      # epsilon = 0.005,
      # diag_sd = 1,
      # TODO some kind of validity check of method? Currently this can only be
      # "SNAPER".
      method = "SNAPER"
    ),
    # target acceptance probability for step size adaptation
    accept_target = 0.651,

    # Build the TFP MCMC transition kernel.
    # sampler_param_vec packs the scalar sampler settings followed by
    # per-dimension values -- NOTE(review): confirm the layout (and the
    # 0-based tensor extraction below) against how greta packs this vector.
    define_tf_kernel = function(sampler_param_vec) {
      dag <- self$model$dag
      tfe <- dag$tf_environment

      # everything after the first two entries corresponds to the free state
      free_state_size <- length(sampler_param_vec) - 2

      # TFP requires an integer tensor for max_leapfrog_steps
      adaptive_hmc_max_leapfrog_steps <- tf$cast(
        x = sampler_param_vec[0],
        dtype = tf$int32
      )
      # TODO pipe that in properly
      n_warmup <- sampler_param_vec[1]

      # SNAPER-HMC adapts the trajectory length during warmup
      kernel_base <- tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo(
        target_log_prob_fn = dag$tf_log_prob_function_adjusted,
        step_size = 1,
        num_adaptation_steps = as.integer(self$warmup),
        max_leapfrog_steps = adaptive_hmc_max_leapfrog_steps
      )

      # wrap with dual-averaging adaptation of the step size
      sampler_kernel <- tfp$mcmc$DualAveragingStepSizeAdaptation(
        inner_kernel = kernel_base,
        num_adaptation_steps = as.integer(self$warmup)
      )

      sampler_kernel
    },

    # Named list of sampler parameter values, used to replace the
    # corresponding tensors when sampling.
    sampler_parameter_values = function() {
      max_leapfrog_steps <- self$parameters$max_leapfrog_steps
      method <- self$parameters$method
      # NOTE(review): removed dead reads of self$parameters$epsilon and
      # matrix(self$parameters$diag_sd); both fields are commented out of
      # `parameters` above, so they were NULL, and matrix(NULL) is an error
      # in modern R. Neither value was used in the returned list.

      # return named list for replacing tensors
      list(
        adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps,
        # adaptive_hmc_epsilon = epsilon,
        # adaptive_hmc_diag_sd = diag_sd,
        method = method
      )
    }
  )
)
82 changes: 82 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1079,3 +1079,85 @@ n_warmup <- function(x) {
x_info <- attr(x, "model_info")
x_info$warmup
}

# Describe a tensor (its shape and dtype, not its values) as a tf.TensorSpec,
# suitable for declaring tf_function input signatures.
as_tensor_spec <- function(tensor) {
  spec_shape <- tensor$shape
  spec_dtype <- tensor$dtype
  tf$TensorSpec(shape = spec_shape, dtype = spec_dtype)
}

# Convert a tensor to its TensorSpec description; pass any non-tensor value
# through unchanged.
maybe_make_tensor_shape <- function(x) {
  if (!tf$is_tensor(x)) {
    return(x)
  }
  as_tensor_spec(x)
}

# get the final model parameter state from a chain as returned in the
# all_states object from tfp$mcmc$sample_chain (iterations along axis 0)
get_last_state <- function(all_states) {
  # 0-based index of the final iteration along the first axis
  last_index <- dim(all_states)[1] - 1L
  tf$gather(all_states, last_index, 0L)
}

# flag MCMC steps whose acceptance probability was non-finite (NA/NaN/Inf),
# i.e. numerically "bad" proposals
bad_steps <- function(kernel_results) {
  ratios <- recursive_get_log_accept_ratio(kernel_results)
  !is.finite(ratios)
}


# recursively extract the log acceptance ratio from the (possibly nested)
# MCMC kernel results, descending through inner_results until it is found
recursive_get_log_accept_ratio <- function(kernel_results) {
  result_names <- names(kernel_results)
  has_ratio <- "log_accept_ratio" %in% result_names
  has_inner <- "inner_results" %in% result_names
  if (has_ratio) {
    ratio <- kernel_results$log_accept_ratio
  } else if (has_inner) {
    ratio <- recursive_get_log_accept_ratio(kernel_results$inner_results)
  } else {
    stop("non-standard kernel structure")
  }
  as.array(ratio)
}


# given a warmed up sampler object, return a compiled TF function that
# generates a new burst of samples from it
make_sampler_function <- function(warm_sampler) {
  # make the uncompiled function (with curried arguments)
  sample_raw <- function(current_state, n_samples) {
    results <- tfp$mcmc$sample_chain(
      # how many iterations
      num_results = n_samples,
      # where to start from
      current_state = current_state,
      # kernel
      kernel = warm_sampler$kernel,
      # tuned sampler settings
      previous_kernel_results = warm_sampler$kernel_results,
      # trace the full kernel results so badness can be diagnosed afterwards
      trace_fn = function(current_state, kernel_results) {
        # could compute badness here to save memory?
        # is.finite(kernel_results$inner_results$inner_results$inner_results$log_accept_ratio)
        kernel_results
      }
    )
    # return the parameter states and the kernel results
    list(
      all_states = results$all_states,
      kernel_results = results$trace
    )
  }

  # compile it into a concrete function and return
  # (fixed: `as_tensorspec` was a typo for the `as_tensor_spec` helper
  # defined above, and would have errored with "could not find function")
  sample <- tf_function(
    sample_raw,
    list(
      as_tensor_spec(warm_sampler$current_state),
      tf$TensorSpec(shape = c(), dtype = tf$int32)
    )
  )

  sample
}
2 changes: 2 additions & 0 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ p_theta_greta <- function(
data,
p_theta,
p_x_bar_theta,
# TODO note that we might want to change this to adaptive_hmc()
sampler = hmc(),
warmup = 1000
) {
Expand Down Expand Up @@ -934,6 +935,7 @@ get_distribution_name <- function(x) {
check_samples <- function(
x,
iid_function,
# TODO note that we might want to change this to adaptive_hmc
sampler = hmc(),
n_effective = 3000,
title = NULL,
Expand Down
61 changes: 61 additions & 0 deletions tests/testthat/test-adaptive-hmc.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# NOTE(review): `2025 - 02 - 13` is arithmetic (evaluates to 2010), not a
# date literal -- the seed actually used is 2010. Presumably an intentional
# date-style mnemonic, but confirm; any fixed value keeps the test deterministic.
set.seed(2025 - 02 - 13)

test_that("bad mcmc proposals are rejected", {
  skip_if_not(check_tf_version())

  # set up for numerical rejection of initial location: data centred at 1e60
  # with a prior at -1e60 and tiny sds, so initial proposals are non-finite
  x <- rnorm(10000, 1e60, 1)
  z <- normal(-1e60, 1e-60)
  distribution(x) <- normal(z, 1e-60)
  m <- model(z, precision = "single")

  # # catch badness in the progress bar
  # NOTE(review): this call uses the default sampler, not adaptive_hmc() --
  # confirm whether it should also pass sampler = adaptive_hmc()
  out <- get_output(
    mcmc(m, n_samples = 10, warmup = 0, pb_update = 10)
  )
  expect_match(out, "100% bad")

  # extreme initial values with adaptive_hmc should error informatively
  expect_snapshot(
    error = TRUE,
    draws <- mcmc(
      m,
      chains = 1,
      n_samples = 2,
      warmup = 0,
      verbose = FALSE,
      sampler = adaptive_hmc(),
      initial_values = initials(z = 1e120)
    )
  )

  # really bad proposals
  x <- rnorm(100000, 1e120, 1)
  z <- normal(-1e120, 1e-120)
  distribution(x) <- normal(z, 1e-120)
  m <- model(z, precision = "single")
  expect_snapshot(
    error = TRUE,
    mcmc(m,
      chains = 1,
      n_samples = 1,
      warmup = 0,
      sampler = adaptive_hmc(),
      verbose = FALSE)
  )

  # proposals that are fine, but rejected anyway (huge step size epsilon)
  # NOTE(review): epsilon is commented out of adaptive_hmc's parameter list
  # in R/samplers.R -- confirm adaptive_hmc() accepts an epsilon argument
  z <- normal(0, 1)
  m <- model(z, precision = "single")
  expect_ok(mcmc(
    m,
    adaptive_hmc(
      epsilon = 100,
      # Lmin = 1,
      # Lmax = 1
    ),
    chains = 1,
    n_samples = 5,
    warmup = 0,
    verbose = FALSE
  ))
})
3 changes: 3 additions & 0 deletions tests/testthat/test_inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ test_that("samplers print informatively", {
expect_snapshot(
hmc(Lmin = 1)
)
expect_snapshot(
adaptive_hmc(Lmin = 1)
)

# # check print sees changed parameters
# out <- capture_output(hmc(Lmin = 1), TRUE)
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test_posteriors_binomial.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,35 @@ test_that("posterior is correct (binomial)", {
suppressWarnings(test <- ks.test(samples, comparison))
expect_gte(test$p.value, 0.01)
})

test_that("posterior is correct (binomial) with adaptive_hmc", {
  skip_if_not(check_tf_version())
  skip_on_os("windows")
  # analytic solution to the posterior of the parameter of a binomial
  # distribution, with uniform prior
  n <- 100
  pos <- rbinom(1, n, runif(1))
  theta <- uniform(0, 1)
  distribution(pos) <- binomial(n, theta)
  m <- model(theta)

  # sample with the adaptive HMC sampler until enough effective samples
  draws <- get_enough_draws(m, adaptive_hmc(), 2000, verbose = FALSE)

  samples <- as.matrix(draws)

  # analytic solution to posterior is beta(1 + pos, 1 + N - pos)
  shape1 <- 1 + pos
  shape2 <- 1 + n - pos

  # qq plot against true quantiles (visual diagnostic only)
  quants <- (1:99) / 100
  q_target <- qbeta(quants, shape1, shape2)
  q_est <- quantile(samples, quants)
  plot(q_target ~ q_est, main = "binomial posterior")
  abline(0, 1)

  # formal check: KS test of MCMC draws vs iid draws from the analytic
  # posterior, sized by the effective sample size
  n_draws <- round(coda::effectiveSize(draws))
  comparison <- rbeta(n_draws, shape1, shape2)
  suppressWarnings(test <- ks.test(samples, comparison))
  expect_gte(test$p.value, 0.01)
})
22 changes: 22 additions & 0 deletions tests/testthat/test_posteriors_chi_squared.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,25 @@ test_that("samplers are unbiased for chi-squared", {

expect_gte(stat$p.value, 0.01)
})

# NOTE: renamed from "samplers are unbiased for chi-squared", which
# duplicated the description of the preceding test in this file; duplicate
# test_that descriptions make failures and snapshots ambiguous.
test_that("adaptive_hmc sampler is unbiased for chi-squared", {
  skip_if_not(check_tf_version())
  skip_on_os("windows")
  # compare adaptive_hmc draws of a chi-squared variable against iid draws
  df <- 5
  x <- chi_squared(df)
  iid <- function(n) rchisq(n, df)

  chi_squared_checked <- check_samples(
    x = x,
    iid_function = iid,
    sampler = adaptive_hmc()
  )

  # do the plotting
  qqplot_checked_samples(chi_squared_checked)

  # do a formal hypothesis test
  stat <- ks_test_mcmc_vs_iid(chi_squared_checked)

  expect_gte(stat$p.value, 0.01)
})
Loading
Loading