Skip to content

Commit 10f76e3

Browse files
committed
NG notes on appropriately handling the kernel
1 parent 188bf1d commit 10f76e3

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

R/inference_class.R

+23-5
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,15 @@ sampler <- R6Class(
307307
},
308308

309309
define_tf_evaluate_sample_batch = function(){
310+
311+
# create a dummy sample_param_vec (vector with length as defined below)
312+
# dummy_sampler_param_vec <- self$sampler_parameter_values()
313+
# create dummy kernel using this, with:
314+
# dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec)
315+
# use dummy kernel to bootrap a dummy results object
316+
# dummy_kernel_results <- dummy_kernel$bootstrap_results()
317+
# use dummy results object to make a tensorspec or whatever
318+
310319
self$tf_evaluate_sample_batch <- tensorflow::tf_function(
311320
f = self$define_tf_draws,
312321
input_signature = list(
@@ -327,6 +336,8 @@ sampler <- R6Class(
327336
)
328337
)
329338
),
339+
# kernel_results
340+
kernel$bootstrap_results()
330341
dtype = tf_float()
331342
)
332343
)
@@ -744,7 +755,8 @@ sampler <- R6Class(
744755
free_state = self$free_state,
745756
sampler_burst_length = as.integer(n_samples),
746757
sampler_thin = as.integer(thin),
747-
sampler_param_vec = param_vec
758+
sampler_param_vec = param_vec,
759+
kernel_results = kernel_results
748760
)
749761

750762
# get trace of free state and drop the null dimension
@@ -789,7 +801,8 @@ sampler <- R6Class(
789801
sample_carefully = function(free_state,
790802
sampler_burst_length,
791803
sampler_thin,
792-
sampler_param_vec) {
804+
sampler_param_vec,
805+
kernel_results) {
793806

794807
# tryCatch handling for numerical errors
795808
dag <- self$model$dag
@@ -799,6 +812,10 @@ sampler <- R6Class(
799812

800813
# ADPATIVE HMC
801814
# TODO - this is where the adaptive_hmc fails at the moment
815+
816+
# so we can pass in the results from the previous kernel
817+
dummy_kernel <- self$define_tf_kernel()
818+
802819
result <- cleanly(
803820
self$tf_evaluate_sample_batch(
804821
free_state = tensorflow::as_tensor(
@@ -811,7 +828,8 @@ sampler <- R6Class(
811828
sampler_param_vec,
812829
dtype = tf_float(),
813830
shape = length(sampler_param_vec)
814-
)
831+
),
832+
kernel_results = kernel_results
815833
)
816834
) # closing cleanly
817835

@@ -1120,8 +1138,8 @@ adaptive_hmc_sampler <- R6Class(
11201138
# return named list for replacing tensors
11211139
list(
11221140
adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps,
1123-
adaptive_hmc_epsilon = epsilon,
1124-
adaptive_hmc_diag_sd = diag_sd,
1141+
# adaptive_hmc_epsilon = epsilon,
1142+
# adaptive_hmc_diag_sd = diag_sd,
11251143
method = method
11261144
)
11271145
}

0 commit comments

Comments
 (0)