Skip to content

Commit a0672ad

Browse files
committed
trying to pass through the appropriate TensorSpec
1 parent 10f76e3 commit a0672ad

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

R/inference_class.R

+16-6
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,20 @@ sampler <- R6Class(
307307
},
308308

309309
define_tf_evaluate_sample_batch = function(){
310+
browser()
311+
312+
dummy_init_state <- matrix(data = 0,
313+
nrow = nrow(self$free_state),
314+
ncol = ncol(self$free_state))
310315

311316
# create a dummy sample_param_vec (vector with length as defined below)
312-
# dummy_sampler_param_vec <- self$sampler_parameter_values()
317+
dummy_sampler_param_vec <- length(unlist(self$sampler_parameter_values()))
313318
# create dummy kernel using this, with:
314-
# dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec)
319+
dummy_kernel <- self$define_tf_kernel(dummy_sampler_param_vec)
315320
# use dummy kernel to bootrap a dummy results object
316-
# dummy_kernel_results <- dummy_kernel$bootstrap_results()
321+
dummy_kernel_results <- dummy_kernel$bootstrap_results(
322+
init_state = dummy_init_state
323+
)
317324
# use dummy results object to make a tensorspec or whatever
318325

319326
self$tf_evaluate_sample_batch <- tensorflow::tf_function(
@@ -335,10 +342,13 @@ sampler <- R6Class(
335342
self$sampler_parameter_values()
336343
)
337344
)
338-
),
345+
)
346+
),
339347
# kernel_results
340-
kernel$bootstrap_results()
341-
dtype = tf_float()
348+
tf$TensorSpec(
349+
shape = list(
350+
length(dummy_kernel_results)
351+
)
342352
)
343353
)
344354
)

0 commit comments

Comments
 (0)