From 1a2e0acbecae87836c1c2dd5efa512c9e0bdd256 Mon Sep 17 00:00:00 2001 From: uri-granta <50578464+uri-granta@users.noreply.github.com> Date: Wed, 9 Oct 2024 09:33:50 +0100 Subject: [PATCH] Improve dataset_len error message (#879) Co-authored-by: Uri Granta --- trieste/ask_tell_optimization.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index 4052ed2ab..f2c2da6d5 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -435,16 +435,18 @@ def acquisition_state(self) -> StateType | None: @classmethod def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int: """Helper method for inferring the global dataset size.""" - dataset_lens = [ - tf.shape(dataset.query_points)[0] + dataset_lens = { + tag: int(tf.shape(dataset.query_points)[0]) for tag, dataset in datasets.items() if not LocalizedTag.from_tag(tag).is_local - ] - unique_lens, _ = tf.unique(dataset_lens) + } + unique_lens, _ = tf.unique(list(dataset_lens.values())) if len(unique_lens) == 1: return int(unique_lens[0]) else: - raise ValueError(f"Expected unique global dataset size, got {unique_lens}") + raise ValueError( + f"Expected unique global dataset size, got {unique_lens}: {dataset_lens}" + ) @classmethod def from_record(