Skip to content

Commit 435d9aa

Browse files
committed
try creating a tensor spec for the kernel results using tf$nest$map_structure
1 parent a0672ad commit 435d9aa

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

R/inference_class.R

+7-5
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,12 @@ sampler <- R6Class(
321321
dummy_kernel_results <- dummy_kernel$bootstrap_results(
322322
init_state = dummy_init_state
323323
)
324+
325+
dummy_kernel_results_tensor_spec <- tf$nest$map_structure(
326+
maybe_make_tensor_shape,
327+
dummy_kernel_results
328+
)
329+
324330
# use dummy results object to make a tensorspec or whatever
325331

326332
self$tf_evaluate_sample_batch <- tensorflow::tf_function(
@@ -345,11 +351,7 @@ sampler <- R6Class(
345351
)
346352
),
347353
# kernel_results
348-
tf$TensorSpec(
349-
shape = list(
350-
length(dummy_kernel_results)
351-
)
352-
)
354+
dummy_kernel_results_tensor_spec
353355
)
354356
)
355357
},

R/utils.R

+13
Original file line numberDiff line numberDiff line change
@@ -1003,3 +1003,16 @@ n_warmup <- function(x){
10031003
x_info <- attr(x, "model_info")
10041004
x_info$warmup
10051005
}
1006+
1007+
build_tensor_spec <- function(tensor){
1008+
tf$TensorSpec(shape = tensor$shape,
1009+
dtype = tensor$dtype)
1010+
}
1011+
1012+
maybe_make_tensor_shape <- function(x){
1013+
if (tf$is_tensor(x)) {
1014+
build_tensor_spec(x)
1015+
} else{
1016+
x
1017+
}
1018+
}

0 commit comments

Comments
 (0)