Commit 435d9aa 1 parent a0672ad commit 435d9aa Copy full SHA for 435d9aa
File tree 2 files changed +20
-5
lines changed
2 files changed +20
-5
lines changed Original file line number Diff line number Diff line change @@ -321,6 +321,12 @@ sampler <- R6Class(
321
321
dummy_kernel_results <- dummy_kernel $ bootstrap_results(
322
322
init_state = dummy_init_state
323
323
)
324
+
325
+ dummy_kernel_results_tensor_spec <- tf $ nest $ map_structure(
326
+ maybe_make_tensor_shape ,
327
+ dummy_kernel_results
328
+ )
329
+
324
330
# use dummy results object to make a tensorspec or whatever
325
331
326
332
self $ tf_evaluate_sample_batch <- tensorflow :: tf_function(
@@ -345,11 +351,7 @@ sampler <- R6Class(
345
351
)
346
352
),
347
353
# kernel_results
348
- tf $ TensorSpec(
349
- shape = list (
350
- length(dummy_kernel_results )
351
- )
352
- )
354
+ dummy_kernel_results_tensor_spec
353
355
)
354
356
)
355
357
},
Original file line number Diff line number Diff line change @@ -1003,3 +1003,16 @@ n_warmup <- function(x){
1003
1003
x_info <- attr(x , " model_info" )
1004
1004
x_info $ warmup
1005
1005
}
1006
+
1007
+ build_tensor_spec <- function (tensor ){
1008
+ tf $ TensorSpec(shape = tensor $ shape ,
1009
+ dtype = tensor $ dtype )
1010
+ }
1011
+
1012
+ maybe_make_tensor_shape <- function (x ){
1013
+ if (tf $ is_tensor(x )) {
1014
+ build_tensor_spec(x )
1015
+ } else {
1016
+ x
1017
+ }
1018
+ }
You can’t perform that action at this time.
0 commit comments