diff --git a/tests/unit/models/gpflow/test_interface.py b/tests/unit/models/gpflow/test_interface.py
index 5eb78f143..5dd74e6f0 100644
--- a/tests/unit/models/gpflow/test_interface.py
+++ b/tests/unit/models/gpflow/test_interface.py
@@ -24,7 +24,7 @@
 from tests.util.misc import random_seed
 from trieste.data import Dataset
 from trieste.models.gpflow import BatchReparametrizationSampler, GPflowPredictor
-from trieste.space import CategoricalSearchSpace
+from trieste.space import CategoricalSearchSpace, one_hot_encoder
 
 
 class _QuadraticPredictor(GPflowPredictor):
@@ -118,7 +118,7 @@ def test_gpflow_reparam_sampler_returns_reparam_sampler_with_correct_samples() -
 def test_gpflow_categorical_predict() -> None:
     search_space = CategoricalSearchSpace(["Red", "Green", "Blue"])
     query_points = search_space.sample(10)
-    model = _QuadraticPredictor(encoder=search_space.one_hot_encoder)
+    model = _QuadraticPredictor(encoder=one_hot_encoder(search_space))
     mean, variance = model.predict(query_points)
     assert mean.shape == [10, 1]
     assert variance.shape == [10, 1]
diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py
index f393b934f..745cdb753 100644
--- a/tests/unit/test_space.py
+++ b/tests/unit/test_space.py
@@ -32,6 +32,7 @@
     CollectionSearchSpace,
     Constraint,
     DiscreteSearchSpace,
+    GeneralDiscreteSearchSpace,
     LinearConstraint,
     NonlinearConstraint,
     SearchSpace,
@@ -210,22 +211,46 @@ def test_discrete_search_space_contains_raises_for_invalid_shapes(
 
 
 @pytest.mark.parametrize("num_samples", [0, 1, 3, 5, 6, 10, 20])
-def test_discrete_search_space_sampling(num_samples: int) -> None:
-    search_space = DiscreteSearchSpace(_points_in_2D_search_space())
+@pytest.mark.parametrize(
+    "search_space",
+    [
+        pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
+        pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
+    ],
+)
+def test_discrete_search_space_sampling(
+    search_space: GeneralDiscreteSearchSpace, num_samples: int
+) -> None:
     samples = search_space.sample(num_samples)
     assert all(sample in search_space for sample in samples)
     assert len(samples) == num_samples
 
 
 @pytest.mark.parametrize("seed", [1, 42, 123])
-def test_discrete_search_space_sampling_returns_same_points_for_same_seed(seed: int) -> None:
-    search_space = DiscreteSearchSpace(_points_in_2D_search_space())
+@pytest.mark.parametrize(
+    "search_space",
+    [
+        pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
+        pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
+    ],
+)
+def test_discrete_search_space_sampling_returns_same_points_for_same_seed(
+    search_space: GeneralDiscreteSearchSpace, seed: int
+) -> None:
     random_samples_1 = search_space.sample(num_samples=100, seed=seed)
     random_samples_2 = search_space.sample(num_samples=100, seed=seed)
     npt.assert_allclose(random_samples_1, random_samples_2)
 
 
-def test_discrete_search_space_sampling_returns_different_points_for_different_call() -> None:
-    search_space = DiscreteSearchSpace(_points_in_2D_search_space())
+@pytest.mark.parametrize(
+    "search_space",
+    [
+        pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
+        pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
+    ],
+)
+def test_discrete_search_space_sampling_returns_different_points_for_different_call(
+    search_space: GeneralDiscreteSearchSpace,
+) -> None:
     random_samples_1 = search_space.sample(num_samples=100)
     random_samples_2 = search_space.sample(num_samples=100)
@@ -1596,6 +1621,45 @@ def test_box_empty_halton_sampling_returns_correct_dtype(dtype: tf.DType) -> Non
     assert sobol_samples.dtype == dtype
 
 
+@pytest.mark.parametrize(
+    "categories, points",
+    [
+        pytest.param([], tf.zeros([0, 0])),
+        pytest.param(3, tf.constant([[0], [1], [2]])),
+        pytest.param([3], tf.constant([[0], [1], [2]])),
+        pytest.param([3, 2], tf.constant([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]])),
+        pytest.param(["R", "G", "B"], tf.constant([[0], [1], [2]])),
+        pytest.param([["R", "G", "B"]], tf.constant([[0], [1], [2]])),
+        pytest.param(
+            [["R", "G", "B"], ["Y", "N"]],
+            tf.constant([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]),
+        ),
+    ],
+)
+def test_categorical_search_space__points(
+    categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], points: TensorType
+) -> None:
+    space = CategoricalSearchSpace(categories)
+    npt.assert_array_equal(space.points, points)
+    contains = space.contains(points)
+    assert len(contains) == len(points)
+    assert tf.reduce_all(contains)
+
+
+@pytest.mark.parametrize(
+    "categories, exception",
+    [
+        pytest.param([3, 0, 2], ValueError, id="Empty category size"),
+        pytest.param([-2, 1], ValueError, id="Negative category size"),
+        pytest.param([["R", "G", "B"], [], ["Y", "N"]], ValueError, id="Empty category list"),
+        pytest.param([3, ["A", "B"]], TypeError, id="Mixed description types"),
+    ],
+)
+def test_categorical_search_space__raises(categories: Any, exception: type) -> None:
+    with pytest.raises(exception):
+        CategoricalSearchSpace(categories)
+
+
 @pytest.mark.parametrize(
     "search_space, query_points, encoded_points",
     [
@@ -1662,9 +1726,9 @@ def test_categorical_search_space_one_hot_encoding(
         ),
     ],
 )
-def test_category_one_hot_encoding_value_errors(
+def test_categorical_search_space_one_hot_encoding__raises(
     search_space: CategoricalSearchSpace, query_points: TensorType, exception: type
 ) -> None:
+    encoder = one_hot_encoder(search_space)
     with pytest.raises(exception):
-        encoder = search_space.one_hot_encoder
         encoder(query_points)
diff --git a/trieste/space.py b/trieste/space.py
index 3c8fdb67f..60c5782dd 100644
--- a/trieste/space.py
+++ b/trieste/space.py
@@ -547,11 +547,21 @@ def __init__(self, categories: int | Sequence[int] | Sequence[str] | Sequence[Se
         if isinstance(categories, int) or any(isinstance(x, str) for x in categories):
             categories = [categories]  # type: ignore[assignment]
 
-        assert isinstance(categories, Sequence)
-        if all(isinstance(x, int) for x in categories):
+        if not isinstance(categories, Sequence) or not (
+            all(
+                isinstance(x, Sequence)
+                and not isinstance(x, str)
+                and all(isinstance(y, str) for y in x)
+                for x in categories
+            )
+            or all(isinstance(x, int) for x in categories)
+        ):
+            raise TypeError("Invalid category description: expected either numbers or names.")
+
+        elif any(isinstance(x, int) for x in categories):
             category_lens: Sequence[int] = categories  # type: ignore[assignment]
             if any(x <= 0 for x in category_lens):
-                raise ValueError("Number of categories must be positive")
+                raise ValueError("Numbers of categories must be positive")
             tags = [tuple(f"{i}" for i in range(n)) for n in category_lens]
         else:
             category_names: Sequence[Sequence[str]] = categories  # type: ignore[assignment]
@@ -564,7 +574,9 @@ def __init__(self, categories: int | Sequence[int] | Sequence[str] | Sequence[Se
         # TODO: inherit from GridSearchSpace to avoid generating the points explicitly?
         ranges = [tf.range(len(ts)) for ts in tags]
         meshgrid = tf.meshgrid(*ranges, indexing="ij")
-        points = tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)])
+        points = (
+            tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)]) if tags else tf.zeros([0, 0])
+        )
 
         super().__init__(points)
 