@@ -307,6 +307,15 @@ sampler <- R6Class(
307
307
},
308
308
309
309
define_tf_evaluate_sample_batch = function (){
310
+
311
+ # create a dummy sample_param_vec (vector with length as defined below)
312
+ # dummy_sampler_param_vec <- self$sampler_parameter_values()
313
+ # create dummy kernel using this, with:
314
+ # dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec)
315
+ # use dummy kernel to bootrap a dummy results object
316
+ # dummy_kernel_results <- dummy_kernel$bootstrap_results()
317
+ # use dummy results object to make a tensorspec or whatever
318
+
310
319
self $ tf_evaluate_sample_batch <- tensorflow :: tf_function(
311
320
f = self $ define_tf_draws ,
312
321
input_signature = list (
@@ -327,6 +336,8 @@ sampler <- R6Class(
327
336
)
328
337
)
329
338
),
339
+ # kernel_results
340
+ kernel $ bootstrap_results()
330
341
dtype = tf_float()
331
342
)
332
343
)
@@ -744,7 +755,8 @@ sampler <- R6Class(
744
755
free_state = self $ free_state ,
745
756
sampler_burst_length = as.integer(n_samples ),
746
757
sampler_thin = as.integer(thin ),
747
- sampler_param_vec = param_vec
758
+ sampler_param_vec = param_vec ,
759
+ kernel_results = kernel_results
748
760
)
749
761
750
762
# get trace of free state and drop the null dimension
@@ -789,7 +801,8 @@ sampler <- R6Class(
789
801
sample_carefully = function (free_state ,
790
802
sampler_burst_length ,
791
803
sampler_thin ,
792
- sampler_param_vec ) {
804
+ sampler_param_vec ,
805
+ kernel_results ) {
793
806
794
807
# tryCatch handling for numerical errors
795
808
dag <- self $ model $ dag
@@ -799,6 +812,10 @@ sampler <- R6Class(
799
812
800
813
# ADPATIVE HMC
801
814
# TODO - this is where the adaptive_hmc fails at the moment
815
+
816
+ # so we can pass in the results from the previous kernel
817
+ dummy_kernel <- self $ define_tf_kernel()
818
+
802
819
result <- cleanly(
803
820
self $ tf_evaluate_sample_batch(
804
821
free_state = tensorflow :: as_tensor(
@@ -811,7 +828,8 @@ sampler <- R6Class(
811
828
sampler_param_vec ,
812
829
dtype = tf_float(),
813
830
shape = length(sampler_param_vec )
814
- )
831
+ ),
832
+ kernel_results = kernel_results
815
833
)
816
834
) # closing cleanly
817
835
@@ -1120,8 +1138,8 @@ adaptive_hmc_sampler <- R6Class(
1120
1138
# return named list for replacing tensors
1121
1139
list (
1122
1140
adaptive_hmc_max_leapfrog_steps = max_leapfrog_steps ,
1123
- adaptive_hmc_epsilon = epsilon ,
1124
- adaptive_hmc_diag_sd = diag_sd ,
1141
+ # adaptive_hmc_epsilon = epsilon,
1142
+ # adaptive_hmc_diag_sd = diag_sd,
1125
1143
method = method
1126
1144
)
1127
1145
}
0 commit comments