Skip to content

Commit

Permalink
Handle the case where either query points or observations have unspec…
Browse files Browse the repository at this point in the history
…ified leading dimension.
  • Loading branch information
avullo committed Aug 13, 2024
1 parent 43e6f01 commit 024ed95
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
43 changes: 43 additions & 0 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
from __future__ import annotations

import contextlib
import copy

import numpy as np
import numpy.testing as npt
import pytest
import tensorflow as tf
Expand Down Expand Up @@ -68,6 +70,47 @@ def test_dataset_raises_for_different_leading_shapes(
Dataset(query_points, observations)


def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None:
@contextlib.contextmanager
def does_not_raise():
try:
yield
except Exception as e:
pytest.fail(f"An exception was raised: {e}")

query_points = tf.zeros((2, 2))
observations = tf.zeros((2, 1))

query_points_var = tf.Variable(
initial_value=np.zeros((0, 2)),
shape=(None, 2),
dtype=tf.float64,
)
observations_var = tf.Variable(
initial_value=np.zeros((0, 1)),
shape=(None, 1),
dtype=tf.float64,
)

with does_not_raise():
Dataset(
query_points=query_points_var,
observations=observations
)

with does_not_raise():
Dataset(
query_points=query_points,
observations=observations_var
)

with does_not_raise():
Dataset(
query_points=query_points_var,
observations=observations_var
)


@pytest.mark.parametrize(
"query_points_shape, observations_shape",
[
Expand Down
2 changes: 1 addition & 1 deletion trieste/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __post_init__(self) -> None:
if (
self.query_points.shape[:-1] != self.observations.shape[:-1]
# can't check dynamic shapes, so trust that they're ok (if not, they'll fail later)
and None not in self.query_points.shape[:-1]
and None not in self.query_points.shape[:-1] and None not in self.observations.shape[:-1]
):
raise ValueError(
f"Leading shapes of query_points and observations must match. Got shapes"
Expand Down

0 comments on commit 024ed95

Please sign in to comment.