Skip to content

Commit cb14e95

Browse files
authored
Merge pull request #725 from njtierney/add-checkers-test-posteriors-723
Add checkers test posteriors 723
2 parents 595f7b2 + 1bf0946 commit cb14e95

10 files changed

+269
-178
lines changed

tests/testthat/helpers.R

+43-10
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
834834
# away from truth. There's a 1/100 chance of any one of these scaled errors
835835
# being greater than qnorm(0.99) if the sampler is correct
836836
errors <- scaled_error(stat_draws, stat_truth)
837-
expect_lte(max(errors), stats::qnorm(0.99))
837+
errors
838838
}
839839

840840
# sample values of greta array 'x' (which must follow a distribution), and
@@ -864,19 +864,52 @@ check_samples <- function(
864864
iid_samples <- iid_function(neff)
865865
mcmc_samples <- as.matrix(draws)
866866

867-
# plot
868-
if (is.null(title)) {
869-
distrib <- get_node(x)$distribution$distribution_name
870-
sampler_name <- class(sampler)[1]
871-
title <- paste(distrib, "with", sampler_name)
872-
}
867+
# # plot
868+
# if (is.null(title)) {
869+
# distrib <- get_node(x)$distribution$distribution_name
870+
# sampler_name <- class(sampler)[1]
871+
# title <- paste(distrib, "with", sampler_name)
872+
# }
873+
874+
# stats::qqplot(mcmc_samples, iid_samples, main = title)
875+
# graphics::abline(0, 1)
876+
877+
# do a formal hypothesis test
878+
# suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
879+
# testthat::expect_gte(stat$p.value, 0.01)
880+
881+
list(
882+
mcmc_samples = mcmc_samples,
883+
iid_samples = iid_samples,
884+
distrib = get_node(x)$distribution$distribution_name,
885+
sampler_name = class(sampler)[1]
886+
)
887+
}
888+
889+
qqplot_checked_samples <- function(checked_samples, title){
890+
891+
distrib <- checked_samples$distrib
892+
sampler_name <- checked_samples$sampler_name
893+
# only build the default "<distribution> with <sampler>" title when the caller
# did not supply one, so an explicit `title` argument is honoured
if (missing(title) || is.null(title)) {
  title <- paste(distrib, "with", sampler_name)
}
894+
895+
mcmc_samples <- checked_samples$mcmc_samples
896+
iid_samples <- checked_samples$iid_samples
897+
898+
stats::qqplot(
899+
x = mcmc_samples,
900+
y = iid_samples,
901+
main = title
902+
)
873903

874-
stats::qqplot(mcmc_samples, iid_samples, main = title)
875904
graphics::abline(0, 1)
905+
}
876906

907+
## helpers for running Kolmogorov-Smirnov test for MCMC samples vs IID samples
908+
ks_test_mcmc_vs_iid <- function(checked_samples){
877909
# do a formal hypothesis test
878-
suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
879-
testthat::expect_gte(stat$p.value, 0.01)
910+
suppressWarnings(stat <- ks.test(checked_samples$mcmc_samples,
911+
checked_samples$iid_samples))
912+
stat
880913
}
881914

882915
## helpers for looping through optimisers

tests/testthat/test_posteriors.R

-163
This file was deleted.
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
Sys.setenv("RELEASE_CANDIDATE" = "true")
2+
test_that("posterior is correct (binomial)", {
3+
skip_if_not(check_tf_version())
4+
5+
skip_if_not_release()
6+
7+
# analytic solution to the posterior of the parameter of a binomial
8+
# distribution, with uniform prior
9+
n <- 100
10+
pos <- rbinom(1, n, runif(1))
11+
theta <- uniform(0, 1)
12+
distribution(pos) <- binomial(n, theta)
13+
m <- model(theta)
14+
15+
draws <- get_enough_draws(m, hmc(), 2000, verbose = FALSE)
16+
17+
samples <- as.matrix(draws)
18+
19+
# analytic solution to posterior is beta(1 + pos, 1 + n - pos)
20+
shape1 <- 1 + pos
21+
shape2 <- 1 + n - pos
22+
23+
# qq plot against true quantiles
24+
quants <- (1:99) / 100
25+
q_target <- qbeta(quants, shape1, shape2)
26+
q_est <- quantile(samples, quants)
27+
plot(q_target ~ q_est, main = "binomial posterior")
28+
abline(0, 1)
29+
30+
n_draws <- round(coda::effectiveSize(draws))
31+
comparison <- rbeta(n_draws, shape1, shape2)
32+
suppressWarnings(test <- ks.test(samples, comparison))
33+
expect_gte(test$p.value, 0.01)
34+
})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Currently takes about 30 seconds on an M1 mac
2+
Sys.setenv("RELEASE_CANDIDATE" = "false")
3+
4+
test_that("samplers are unbiased for bivariate normals", {
5+
skip_if_not(check_tf_version())
6+
7+
skip_if_not_release()
8+
9+
hmc_mvn_samples <- check_mvn_samples(sampler = hmc())
10+
expect_lte(max(hmc_mvn_samples), stats::qnorm(0.99))
11+
12+
rwmh_mvn_samples <- check_mvn_samples(sampler = rwmh())
13+
expect_lte(max(rwmh_mvn_samples), stats::qnorm(0.99))
14+
15+
slice_mvn_samples <- check_mvn_samples(sampler = slice())
16+
expect_lte(max(slice_mvn_samples), stats::qnorm(0.99))
17+
})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
Sys.setenv("RELEASE_CANDIDATE" = "true")
2+
3+
test_that("samplers are unbiased for chi-squared", {
4+
skip_if_not(check_tf_version())
5+
6+
skip_if_not_release()
7+
8+
df <- 5
9+
x <- chi_squared(df)
10+
iid <- function(n) rchisq(n, df)
11+
12+
chi_squared_checked <- check_samples(x = x,
13+
iid_function = iid,
14+
sampler = hmc())
15+
16+
# do the plotting
17+
qqplot_checked_samples(chi_squared_checked)
18+
19+
# do a formal hypothesis test
20+
stat <- ks_test_mcmc_vs_iid(chi_squared_checked)
21+
22+
expect_gte(stat$p.value, 0.01)
23+
})
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
Sys.setenv("RELEASE_CANDIDATE" = "false")
2+
3+
## TF1/2 - method for this test needs to be updated for TF2
4+
## See https://github.com/greta-dev/greta/issues/720
5+
test_that("samplers pass geweke tests", {
6+
skip_if_not(check_tf_version())
7+
8+
skip_if_not_release()
9+
10+
# nolint start
11+
# run geweke tests on this model:
12+
# theta ~ normal(mu1, sd1)
13+
# x[i] ~ normal(theta, sd2)
14+
# for i in N
15+
# nolint end
16+
17+
n <- 10
18+
mu1 <- rnorm(1, 0, 3)
19+
sd1 <- rlnorm(1)
20+
sd2 <- rlnorm(1)
21+
22+
# prior (n draws)
23+
p_theta <- function(n) {
24+
rnorm(n, mu1, sd1)
25+
}
26+
27+
# likelihood
28+
p_x_bar_theta <- function(theta) {
29+
rnorm(n, theta, sd2)
30+
}
31+
32+
# define the greta model (single precision for slice sampler)
33+
x <- as_data(rep(0, n))
34+
greta_theta <- normal(mu1, sd1)
35+
distribution(x) <- normal(greta_theta, sd2)
36+
model <- model(greta_theta, precision = "single")
37+
38+
# run tests on all available samplers
39+
check_geweke(
40+
sampler = hmc(),
41+
model = model,
42+
data = x,
43+
p_theta = p_theta,
44+
p_x_bar_theta = p_x_bar_theta,
45+
title = "HMC Geweke test"
46+
)
47+
48+
check_geweke(
49+
sampler = rwmh(),
50+
model = model,
51+
data = x,
52+
p_theta = p_theta,
53+
p_x_bar_theta = p_x_bar_theta,
54+
warmup = 2000,
55+
title = "RWMH Geweke test"
56+
)
57+
58+
check_geweke(
59+
sampler = slice(),
60+
model = model,
61+
data = x,
62+
p_theta = p_theta,
63+
p_x_bar_theta = p_x_bar_theta,
64+
title = "slice sampler Geweke test"
65+
)
66+
})

0 commit comments

Comments
 (0)