Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 30, 2024
1 parent 6bc5d4d commit 74122e5
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tests/unit/models/gpflow/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tests.util.misc import random_seed
from trieste.data import Dataset
from trieste.models.gpflow import BatchReparametrizationSampler, GPflowPredictor
from trieste.space import CategoricalSearchSpace
from trieste.space import CategoricalSearchSpace, one_hot_encoder


class _QuadraticPredictor(GPflowPredictor):
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_gpflow_reparam_sampler_returns_reparam_sampler_with_correct_samples() -
def test_gpflow_categorical_predict() -> None:
search_space = CategoricalSearchSpace(["Red", "Green", "Blue"])
query_points = search_space.sample(10)
model = _QuadraticPredictor(encoder=search_space.one_hot_encoder)
model = _QuadraticPredictor(encoder=one_hot_encoder(search_space))
mean, variance = model.predict(query_points)
assert mean.shape == [10, 1]
assert variance.shape == [10, 1]
Expand Down
79 changes: 72 additions & 7 deletions tests/unit/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
CollectionSearchSpace,
Constraint,
DiscreteSearchSpace,
GeneralDiscreteSearchSpace,
LinearConstraint,
NonlinearConstraint,
SearchSpace,
Expand Down Expand Up @@ -210,22 +211,47 @@ def test_discrete_search_space_contains_raises_for_invalid_shapes(


@pytest.mark.parametrize("num_samples", [0, 1, 3, 5, 6, 10, 20])
@pytest.mark.parametrize(
    "search_space",
    [
        pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
        pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
    ],
)
def test_discrete_search_space_sampling(
    search_space: GeneralDiscreteSearchSpace, num_samples: int
) -> None:
    """Check that sampling yields ``num_samples`` points, all contained in the space.

    Runs once per combination of the parametrized search space and sample count.
    """
    samples = search_space.sample(num_samples)
    assert all(sample in search_space for sample in samples)
    assert len(samples) == num_samples


@pytest.mark.parametrize("seed", [1, 42, 123])
@pytest.mark.parametrize(
    "search_space",
    [
        pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
        pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
    ],
)
def test_discrete_search_space_sampling_returns_same_points_for_same_seed(
    search_space: GeneralDiscreteSearchSpace, seed: int
) -> None:
    """Check that two draws made with the same ``seed`` return identical samples."""
    random_samples_1 = search_space.sample(num_samples=100, seed=seed)
    random_samples_2 = search_space.sample(num_samples=100, seed=seed)
    npt.assert_allclose(random_samples_1, random_samples_2)


@pytest.mark.parametrize(
    "search_space",
    [
        pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
        pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
    ],
)
def test_discrete_search_space_sampling_returns_different_points_for_different_call(
    search_space: GeneralDiscreteSearchSpace,
) -> None:
    """Draw two unseeded batches of 100 samples from the parametrized search space."""
    # Sample from the injected ``search_space`` directly: rebinding the name to a
    # DiscreteSearchSpace here would make every parametrized case test the same space.
    random_samples_1 = search_space.sample(num_samples=100)
    random_samples_2 = search_space.sample(num_samples=100)
Expand Down Expand Up @@ -1596,6 +1622,45 @@ def test_box_empty_halton_sampling_returns_correct_dtype(dtype: tf.DType) -> Non
assert sobol_samples.dtype == dtype


@pytest.mark.parametrize(
    "categories, points",
    [
        pytest.param([], tf.zeros([0, 0])),
        pytest.param(3, tf.constant([[0], [1], [2]])),
        pytest.param([3], tf.constant([[0], [1], [2]])),
        pytest.param([3, 2], tf.constant([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]])),
        pytest.param(["R", "G", "B"], tf.constant([[0], [1], [2]])),
        pytest.param([["R", "G", "B"]], tf.constant([[0], [1], [2]])),
        pytest.param(
            [["R", "G", "B"], ["Y", "N"]],
            tf.constant([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]),
        ),
    ],
)
def test_categorical_search_space__points(
    categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], points: TensorType
) -> None:
    """Check the points of a categorical space built from ``categories``.

    The space's ``points`` must equal the expected ``points``, and ``contains`` must
    report one truthy entry per expected point.
    """
    search_space = CategoricalSearchSpace(categories)
    npt.assert_array_equal(search_space.points, points)

    membership = search_space.contains(points)
    assert len(membership) == len(points)
    assert tf.reduce_all(membership)


@pytest.mark.parametrize(
    "categories, exception",
    [
        pytest.param([3, 0, 2], ValueError, id="Empty category size"),
        pytest.param([-2, 1], ValueError, id="Negative category size"),
        pytest.param([["R", "G", "B"], [], ["Y", "N"]], ValueError, id="Empty category list"),
        pytest.param([3, ["A", "B"]], TypeError, id="Mixed description types"),
    ],
)
def test_categorical_search_space__raises(categories: Any, exception: type) -> None:
    """Check that constructing a space from an invalid description raises ``exception``."""
    with pytest.raises(exception):
        CategoricalSearchSpace(categories)


@pytest.mark.parametrize(
"search_space, query_points, encoded_points",
[
Expand Down Expand Up @@ -1662,9 +1727,9 @@ def test_categorical_search_space_one_hot_encoding(
),
],
)
def test_category_one_hot_encoding_value_errors(
def test_categorical_search_space_one_hot_encoding__raises(
search_space: CategoricalSearchSpace, query_points: TensorType, exception: type
) -> None:
encoder = one_hot_encoder(search_space)
with pytest.raises(exception):
encoder = search_space.one_hot_encoder
encoder(query_points)
20 changes: 16 additions & 4 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,21 @@ def __init__(self, categories: int | Sequence[int] | Sequence[str] | Sequence[Se
if isinstance(categories, int) or any(isinstance(x, str) for x in categories):
categories = [categories] # type: ignore[assignment]

assert isinstance(categories, Sequence)
if all(isinstance(x, int) for x in categories):
if not isinstance(categories, Sequence) or not (
all(
isinstance(x, Sequence)
and not isinstance(x, str)
and all(isinstance(y, str) for y in x)
for x in categories
)
or all(isinstance(x, int) for x in categories)
):
raise TypeError("Invalid category description: expected either numbers or names.")

elif any(isinstance(x, int) for x in categories):
category_lens: Sequence[int] = categories # type: ignore[assignment]
if any(x <= 0 for x in category_lens):
raise ValueError("Number of categories must be positive")
raise ValueError("Numbers of categories must be positive")
tags = [tuple(f"{i}" for i in range(n)) for n in category_lens]
else:
category_names: Sequence[Sequence[str]] = categories # type: ignore[assignment]
Expand All @@ -564,7 +574,9 @@ def __init__(self, categories: int | Sequence[int] | Sequence[str] | Sequence[Se
# TODO: inherit from GridSearchSpace to avoid generating the points explicitly?
ranges = [tf.range(len(ts)) for ts in tags]
meshgrid = tf.meshgrid(*ranges, indexing="ij")
points = tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)])
points = (
tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)]) if tags else tf.zeros([0, 0])
)

super().__init__(points)

Expand Down

0 comments on commit 74122e5

Please sign in to comment.