@@ -131,22 +131,48 @@ sampler <- R6Class(
131
131
132
132
# create these objects if needed
133
133
if (from_scratch ) {
134
- self $ traced_free_state <- empty_matrices(n = self $ n_chains ,
134
+ self $ traced_free_state <- self $ empty_matrices(n = self $ n_chains ,
135
135
ncol = self $ n_free )
136
136
137
- self $ traced_values <- empty_matrices(n = self $ n_chains ,
137
+ self $ traced_values <- self $ empty_matrices(n = self $ n_chains ,
138
138
ncol = self $ n_traced )
139
139
}
140
140
141
141
# how big would we like the bursts to be
142
142
ideal_burst_size <- ifelse(one_by_one , 1L , pb_update )
143
143
144
- # if warmup is required, do that now
145
- if (warmup > 0 ) {
144
+ self $ run_warmup(
145
+ n_samples = n_samples ,
146
+ pb_update = pb_update ,
147
+ ideal_burst_size = ideal_burst_size ,
148
+ verbose = verbose
149
+ )
150
+
151
+ self $ run_sampling(
152
+ n_samples = n_samples ,
153
+ pb_update = pb_update ,
154
+ ideal_burst_size = ideal_burst_size ,
155
+ trace_batch_size = trace_batch_size ,
156
+ thin = thin ,
157
+ verbose = verbose
158
+ )
159
+
160
+ # return self, to send results back when running in parallel
161
+ self
162
+ },
163
+
164
+ run_warmup = function (
165
+ n_samples ,
166
+ pb_update ,
167
+ ideal_burst_size ,
168
+ verbose
169
+ ) {
170
+ perform_warmup <- self $ warmup > 0
171
+ if (perform_warmup ) {
146
172
if (verbose ) {
147
173
pb_warmup <- create_progress_bar(
148
174
" warmup" ,
149
- c(warmup , n_samples ),
175
+ c(self $ warmup , n_samples ),
150
176
pb_update ,
151
177
self $ pb_width
152
178
)
@@ -157,7 +183,7 @@ sampler <- R6Class(
157
183
}
158
184
159
185
# split up warmup iterations into bursts of sampling
160
- burst_lengths <- self $ burst_lengths(warmup ,
186
+ burst_lengths <- self $ burst_lengths(self $ warmup ,
161
187
ideal_burst_size ,
162
188
warmup = TRUE )
163
189
@@ -178,7 +204,7 @@ sampler <- R6Class(
178
204
self $ trace()
179
205
# a memory efficient way to calculate summary stats of samples
180
206
self $ update_welford()
181
- self $ tune(completed_iterations [burst ], warmup )
207
+ self $ tune(completed_iterations [burst ], self $ warmup )
182
208
183
209
if (verbose ) {
184
210
@@ -190,21 +216,31 @@ sampler <- R6Class(
190
216
file = self $ pb_file
191
217
)
192
218
193
- self $ write_percentage_log(warmup ,
219
+ self $ write_percentage_log(self $ warmup ,
194
220
completed_iterations [burst ],
195
221
stage = " warmup"
196
222
)
197
223
}
198
224
}
199
225
200
226
# scrub the free state trace and numerical rejections
201
- self $ traced_free_state <- empty_matrices(n = self $ n_chains ,
227
+ self $ traced_free_state <- self $ empty_matrices(n = self $ n_chains ,
202
228
ncol = self $ n_free )
203
229
204
230
self $ numerical_rejections <- 0
205
- }
231
+ } # end warmup
232
+ },
206
233
207
- if (n_samples > 0 ) {
234
+ run_sampling = function (
235
+ n_samples ,
236
+ pb_update ,
237
+ ideal_burst_size ,
238
+ trace_batch_size ,
239
+ thin ,
240
+ verbose
241
+ ){
242
+ perform_sampling <- n_samples > 0
243
+ if (perform_sampling ) {
208
244
209
245
# on exiting during the main sampling period (even if killed by the
210
246
# user) trace the free state values
@@ -215,7 +251,7 @@ sampler <- R6Class(
215
251
if (verbose ) {
216
252
pb_sampling <- create_progress_bar(
217
253
" sampling" ,
218
- c(warmup , n_samples ),
254
+ c(self $ warmup , n_samples ),
219
255
pb_update ,
220
256
self $ pb_width
221
257
)
@@ -254,10 +290,8 @@ sampler <- R6Class(
254
290
)
255
291
}
256
292
}
257
- }
293
+ } # end sampling
258
294
259
- # return self, to send results back when running in parallel
260
- self
261
295
},
262
296
263
297
# update the welford accumulator for summary statistics of the posterior,
@@ -616,6 +650,12 @@ sampler <- R6Class(
616
650
617
651
# random number of integration steps
618
652
self $ parameters
619
- }
653
+ },
654
+ empty_matrices = function (n ,
655
+ ncol ){
656
+ replicate(n = n ,
657
+ matrix (data = NA , nrow = 0 , ncol = ncol ),
658
+ simplify = FALSE )
659
+ }
620
660
)
621
661
)
0 commit comments