From 024ed95d9fa98930915bd584d357f769a989d1d4 Mon Sep 17 00:00:00 2001 From: Alessandro Vullo Date: Tue, 13 Aug 2024 13:26:04 +0100 Subject: [PATCH] Handle the case where either query points or observations have unspecified leading dimension. --- tests/unit/test_data.py | 43 +++++++++++++++++++++++++++++++++++++++++ trieste/data.py | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 78c1c6df42..be10470990 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -13,8 +13,10 @@ # limitations under the License. from __future__ import annotations +import contextlib import copy +import numpy as np import numpy.testing as npt import pytest import tensorflow as tf @@ -68,6 +70,47 @@ def test_dataset_raises_for_different_leading_shapes( Dataset(query_points, observations) +def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None: + @contextlib.contextmanager + def does_not_raise(): + try: + yield + except Exception as e: + pytest.fail(f"An exception was raised: {e}") + + query_points = tf.zeros((2, 2)) + observations = tf.zeros((2, 1)) + + query_points_var = tf.Variable( + initial_value=np.zeros((0, 2)), + shape=(None, 2), + dtype=tf.float64, + ) + observations_var = tf.Variable( + initial_value=np.zeros((0, 1)), + shape=(None, 1), + dtype=tf.float64, + ) + + with does_not_raise(): + Dataset( + query_points=query_points_var, + observations=observations + ) + + with does_not_raise(): + Dataset( + query_points=query_points, + observations=observations_var + ) + + with does_not_raise(): + Dataset( + query_points=query_points_var, + observations=observations_var + ) + + @pytest.mark.parametrize( "query_points_shape, observations_shape", [ diff --git a/trieste/data.py b/trieste/data.py index 6c979a30e0..3a505997a2 100644 --- a/trieste/data.py +++ b/trieste/data.py @@ -52,7 +52,7 @@ def __post_init__(self) -> None: if ( self.query_points.shape[:-1] != self.observations.shape[:-1] # can't check dynamic shapes, so trust that they're ok (if not, they'll fail later) - and None not in self.query_points.shape[:-1] + and None not in self.query_points.shape[:-1] and None not in self.observations.shape[:-1] ): raise ValueError( f"Leading shapes of query_points and observations must match. Got shapes"