Skip to content

Commit add554c

Browse files
committed
abstract warmup and sampling into run_warmup and run_sampling
1 parent e74444c commit add554c

File tree

2 files changed

+56
-23
lines changed

2 files changed

+56
-23
lines changed

R/sampler_class.R

+56-16
Original file line numberDiff line numberDiff line change
@@ -131,22 +131,48 @@ sampler <- R6Class(
131131

132132
# create these objects if needed
133133
if (from_scratch) {
134-
self$traced_free_state <- empty_matrices(n = self$n_chains,
134+
self$traced_free_state <- self$empty_matrices(n = self$n_chains,
135135
ncol = self$n_free)
136136

137-
self$traced_values <- empty_matrices(n = self$n_chains,
137+
self$traced_values <- self$empty_matrices(n = self$n_chains,
138138
ncol = self$n_traced)
139139
}
140140

141141
# how big would we like the bursts to be
142142
ideal_burst_size <- ifelse(one_by_one, 1L, pb_update)
143143

144-
# if warmup is required, do that now
145-
if (warmup > 0) {
144+
self$run_warmup(
145+
n_samples = n_samples,
146+
pb_update = pb_update,
147+
ideal_burst_size = ideal_burst_size,
148+
verbose = verbose
149+
)
150+
151+
self$run_sampling(
152+
n_samples = n_samples,
153+
pb_update = pb_update,
154+
ideal_burst_size = ideal_burst_size,
155+
trace_batch_size = trace_batch_size,
156+
thin = thin,
157+
verbose = verbose
158+
)
159+
160+
# return self, to send results back when running in parallel
161+
self
162+
},
163+
164+
run_warmup = function(
165+
n_samples,
166+
pb_update,
167+
ideal_burst_size,
168+
verbose
169+
) {
170+
perform_warmup <- self$warmup > 0
171+
if (perform_warmup) {
146172
if (verbose) {
147173
pb_warmup <- create_progress_bar(
148174
"warmup",
149-
c(warmup, n_samples),
175+
c(self$warmup, n_samples),
150176
pb_update,
151177
self$pb_width
152178
)
@@ -157,7 +183,7 @@ sampler <- R6Class(
157183
}
158184

159185
# split up warmup iterations into bursts of sampling
160-
burst_lengths <- self$burst_lengths(warmup,
186+
burst_lengths <- self$burst_lengths(self$warmup,
161187
ideal_burst_size,
162188
warmup = TRUE)
163189

@@ -178,7 +204,7 @@ sampler <- R6Class(
178204
self$trace()
179205
# a memory efficient way to calculate summary stats of samples
180206
self$update_welford()
181-
self$tune(completed_iterations[burst], warmup)
207+
self$tune(completed_iterations[burst], self$warmup)
182208

183209
if (verbose) {
184210

@@ -190,21 +216,31 @@ sampler <- R6Class(
190216
file = self$pb_file
191217
)
192218

193-
self$write_percentage_log(warmup,
219+
self$write_percentage_log(self$warmup,
194220
completed_iterations[burst],
195221
stage = "warmup"
196222
)
197223
}
198224
}
199225

200226
# scrub the free state trace and numerical rejections
201-
self$traced_free_state <- empty_matrices(n = self$n_chains,
227+
self$traced_free_state <- self$empty_matrices(n = self$n_chains,
202228
ncol = self$n_free)
203229

204230
self$numerical_rejections <- 0
205-
}
231+
} # end warmup
232+
},
206233

207-
if (n_samples > 0) {
234+
run_sampling = function (
235+
n_samples,
236+
pb_update,
237+
ideal_burst_size,
238+
trace_batch_size,
239+
thin,
240+
verbose
241+
){
242+
perform_sampling <- n_samples > 0
243+
if (perform_sampling) {
208244

209245
# on exiting during the main sampling period (even if killed by the
210246
# user) trace the free state values
@@ -215,7 +251,7 @@ sampler <- R6Class(
215251
if (verbose) {
216252
pb_sampling <- create_progress_bar(
217253
"sampling",
218-
c(warmup, n_samples),
254+
c(self$warmup, n_samples),
219255
pb_update,
220256
self$pb_width
221257
)
@@ -254,10 +290,8 @@ sampler <- R6Class(
254290
)
255291
}
256292
}
257-
}
293+
} # end sampling
258294

259-
# return self, to send results back when running in parallel
260-
self
261295
},
262296

263297
# update the welford accumulator for summary statistics of the posterior,
@@ -616,6 +650,12 @@ sampler <- R6Class(
616650

617651
# random number of integration steps
618652
self$parameters
619-
}
653+
},
654+
empty_matrices = function(n,
655+
ncol){
656+
replicate(n = n,
657+
matrix(data = NA, nrow = 0, ncol = ncol),
658+
simplify = FALSE)
659+
}
620660
)
621661
)

R/utils.R

-7
Original file line numberDiff line numberDiff line change
@@ -1080,10 +1080,3 @@ n_warmup <- function(x){
10801080
x_info <- attr(x, "model_info")
10811081
x_info$warmup
10821082
}
1083-
1084-
empty_matrices <- function(n,
1085-
ncol){
1086-
replicate(n = n,
1087-
matrix(data = NA, nrow = 0, ncol = ncol),
1088-
simplify = FALSE)
1089-
}

0 commit comments

Comments
 (0)