Skip to content

Commit cfdc095

Browse files
authored
Merge pull request #743 from njtierney/update-mcmc-seed-docs-736
strike a compromise for #736
2 parents 7d1f57e + 6d0f9fd commit cfdc095

File tree

3 files changed

+68
-5
lines changed

3 files changed

+68
-5
lines changed

R/inference.R

+24-2
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ greta_stash$numerical_messages <- c(
108108
#' argument `trace_batch_size` can be modified to trade-off speed against
109109
#' memory usage.
110110
#'
111-
#' @note to set a seed with MCMC you must use [tensorflow::set_random_seed()].
112-
#' This is due to an internal API with tensorflow. See \url{https://github.com/greta-dev/greta/issues/559} for a thread exploring this.
111+
#' @note to set a seed with MCMC you can use [set.seed()], or
112+
#' [tensorflow::set_random_seed()]. They both given identical results. See
113+
#' examples below.
113114
#'
114115
#' @return `mcmc`, `stashed_samples` & `extra_samples` - a
115116
#' `greta_mcmc_list` object that can be analysed using functions from the
@@ -183,6 +184,27 @@ greta_stash$numerical_messages <- c(
183184
#' m3 <- model(params)
184185
#' o <- opt(m3, hessian = TRUE)
185186
#' o$hessian
187+
#'
188+
#' # using set.seed or tensorflow::set_random_seed to set RNG for MCMC
189+
#' a <- normal(0, 1)
190+
#' y <- normal(a, 1)
191+
#' m <- model(y)
192+
#'
193+
#' set.seed(12345)
194+
#' one <- mcmc(m, n_samples = 1, chains = 1)
195+
#' set.seed(12345)
196+
#' two <- mcmc(m, n_samples = 1, chains = 1)
197+
#' # same
198+
#' all.equal(as.numeric(one), as.numeric(two))
199+
#' tensorflow::set_random_seed(12345)
200+
#' one_tf <- mcmc(m, n_samples = 1, chains = 1)
201+
#' tensorflow::set_random_seed(12345)
202+
#' two_tf <- mcmc(m, n_samples = 1, chains = 1)
203+
#' # same
204+
#' all.equal(as.numeric(one_tf), as.numeric(two_tf))
205+
#' # different
206+
#' all.equal(as.numeric(one), as.numeric(one_tf))
207+
#'
186208
#' }
187209
mcmc <- function(
188210
model,

man/inference.Rd

+24-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_seed.R

+20-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ test_that("calculate samples are the same when the R seed is the same", {
108108
)
109109
})
110110

111-
test_that("mcmc samples are the same when the R seed is the same", {
111+
test_that("mcmc samples are the same when the R seed is the same, also with tf set seed", {
112112
skip_if_not(check_tf_version())
113113
a <- normal(0, 1)
114114
y <- normal(a, 1)
@@ -123,6 +123,25 @@ test_that("mcmc samples are the same when the R seed is the same", {
123123
as.numeric(one),
124124
as.numeric(two)
125125
)
126+
127+
tensorflow::set_random_seed(12345)
128+
one_tf <- mcmc(m, warmup = 10, n_samples = 1, chains = 1)
129+
tensorflow::set_random_seed(12345)
130+
two_tf <- mcmc(m, warmup = 10, n_samples = 1, chains = 1)
131+
132+
expect_equal(
133+
as.numeric(one_tf),
134+
as.numeric(two_tf)
135+
)
136+
137+
# but these are not (always) equal to each other
138+
mcmc_matches_tf_one <- identical(as.numeric(one),as.numeric(one_tf))
139+
mcmc_matches_tf_two <- identical(as.numeric(two),as.numeric(two_tf))
140+
141+
expect_false(mcmc_matches_tf_one)
142+
143+
expect_false(mcmc_matches_tf_two)
144+
126145
})
127146

128147
test_that("simulate uses the local RNG seed", {

0 commit comments

Comments
 (0)