Skip to content

Commit

Permalink
Categorical search spaces (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored Aug 8, 2024
1 parent f67c85c commit 43e6f01
Show file tree
Hide file tree
Showing 2 changed files with 533 additions and 36 deletions.
282 changes: 275 additions & 7 deletions tests/unit/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
from tensorflow.python.framework.errors_impl import InvalidArgumentError
from typing_extensions import Final

from tests.util.misc import TF_DEBUGGING_ERROR_TYPES, ShapeLike, various_shapes
from trieste.space import (
Box,
CategoricalSearchSpace,
CollectionSearchSpace,
Constraint,
DiscreteSearchSpace,
GeneralDiscreteSearchSpace,
LinearConstraint,
NonlinearConstraint,
SearchSpace,
TaggedMultiSearchSpace,
TaggedProductSearchSpace,
one_hot_encoder,
)
from trieste.types import TensorType

Expand All @@ -44,6 +48,10 @@ def __init__(self, exclusive_limit: int):
assert exclusive_limit > 0
self.limit: Final[int] = exclusive_limit

@property
def has_bounds(self) -> bool:
    """Report whether this space has bounds; this implementation always answers ``True``."""
    return True

@property
def lower(self) -> None:
    """Lower-bound placeholder: the body is a bare ``pass``, so the property evaluates to ``None``."""
    pass
Expand Down Expand Up @@ -183,6 +191,7 @@ def test_discrete_search_space_returns_correct_dimension(
def test_discrete_search_space_returns_correct_bounds(
    space: DiscreteSearchSpace, lower: tf.Tensor, upper: tf.Tensor
) -> None:
    """Check that the space reports having bounds and that both bounds match.

    :param space: The discrete search space under test.
    :param lower: The expected value of ``space.lower``.
    :param upper: The expected value of ``space.upper``.
    """
    assert space.has_bounds
    # Attributes are looked up lazily so that ``upper`` is only read once ``lower`` has passed.
    for bound_name, expected in (("lower", lower), ("upper", upper)):
        npt.assert_array_equal(getattr(space, bound_name), expected)

Expand All @@ -207,22 +216,47 @@ def test_discrete_search_space_contains_raises_for_invalid_shapes(


@pytest.mark.parametrize("num_samples", [0, 1, 3, 5, 6, 10, 20])
def test_discrete_search_space_sampling(num_samples: int) -> None:
search_space = DiscreteSearchSpace(_points_in_2D_search_space())
@pytest.mark.parametrize(
"search_space",
[
pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
],
)
def test_discrete_search_space_sampling(
search_space: GeneralDiscreteSearchSpace, num_samples: int
) -> None:
samples = search_space.sample(num_samples)
assert all(sample in search_space for sample in samples)
assert len(samples) == num_samples


@pytest.mark.parametrize("seed", [1, 42, 123])
def test_discrete_search_space_sampling_returns_same_points_for_same_seed(seed: int) -> None:
search_space = DiscreteSearchSpace(_points_in_2D_search_space())
@pytest.mark.parametrize(
"search_space",
[
pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
],
)
def test_discrete_search_space_sampling_returns_same_points_for_same_seed(
search_space: GeneralDiscreteSearchSpace, seed: int
) -> None:
random_samples_1 = search_space.sample(num_samples=100, seed=seed)
random_samples_2 = search_space.sample(num_samples=100, seed=seed)
npt.assert_allclose(random_samples_1, random_samples_2)


def test_discrete_search_space_sampling_returns_different_points_for_different_call() -> None:
@pytest.mark.parametrize(
"search_space",
[
pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"),
pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"),
],
)
def test_discrete_search_space_sampling_returns_different_points_for_different_call(
search_space: GeneralDiscreteSearchSpace,
) -> None:
search_space = DiscreteSearchSpace(_points_in_2D_search_space())
random_samples_1 = search_space.sample(num_samples=100)
random_samples_2 = search_space.sample(num_samples=100)
Expand Down Expand Up @@ -380,6 +414,7 @@ def test_box_returns_correct_dimension(space: Box, dimension: int) -> None:
def test_box_bounds_attributes() -> None:
    """Check that a ``Box`` reports having bounds and returns the bounds it was built from."""
    lower = tf.zeros([2])
    upper = tf.ones([2])
    box = Box(lower, upper)

    assert box.has_bounds
    npt.assert_array_equal(box.lower, lower)
    npt.assert_array_equal(box.upper, upper)

Expand Down Expand Up @@ -898,6 +933,7 @@ def test_collection_space_contains_raises_on_point_of_different_shape(
[
(TaggedMultiSearchSpace, Box([-1, -2], [2, 3]), [2, 2]),
(TaggedProductSearchSpace, Box([-1], [2]), [3]),
(TaggedProductSearchSpace, CategoricalSearchSpace(["A", "B", "C"]), [3]),
],
)
@pytest.mark.parametrize("num_samples", [0, 1, 10])
Expand Down Expand Up @@ -1081,6 +1117,7 @@ def test_product_space_returns_correct_bounds(
spaces: Sequence[SearchSpace], lower: tf.Tensor, upper: tf.Tensor
) -> None:
for space in (TaggedProductSearchSpace(spaces=spaces), reduce(operator.mul, spaces)):
assert space.has_bounds
npt.assert_array_equal(space.lower, lower)
npt.assert_array_equal(space.upper, upper)

Expand Down Expand Up @@ -1150,10 +1187,10 @@ def test_product_space_fix_subspace_doesnt_fix_undesired_subspace(points: tf.Ten
[
Box([-1, -2], [2, 3]),
DiscreteSearchSpace(tf.constant([[-0.5]])),
Box([-1], [2]),
CategoricalSearchSpace([3, 2]),
],
["A", "B", "C"],
{"A": [0, 2], "B": [2, 3], "C": [3, 4]},
{"A": [0, 2], "B": [2, 3], "C": [3, 5]},
),
],
)
Expand Down Expand Up @@ -1438,6 +1475,11 @@ def _nlc_func(x: TensorType) -> TensorType:
),
False,
),
(CategoricalSearchSpace([3, 2]), CategoricalSearchSpace([3, 2]), True),
(CategoricalSearchSpace([3]), CategoricalSearchSpace([3, 2]), False),
(CategoricalSearchSpace(3), CategoricalSearchSpace(["0", "1", "2"]), True),
(CategoricalSearchSpace(3), CategoricalSearchSpace(["R", "G", "B"]), False),
(CategoricalSearchSpace(3), DiscreteSearchSpace(tf.constant([[0], [1]])), False),
],
)
def test___eq___search_spaces(a: SearchSpace, b: SearchSpace, equal: bool) -> None:
Expand Down Expand Up @@ -1591,3 +1633,229 @@ def test_box_empty_halton_sampling_returns_correct_dtype(dtype: tf.DType) -> Non
box = Box(tf.zeros((3,), dtype=dtype), tf.ones((3,), dtype=dtype))
sobol_samples = box.sample_halton(0)
assert sobol_samples.dtype == dtype


@pytest.mark.parametrize(
    "categories, points",
    [
        ([], tf.zeros([0, 0])),
        (3, tf.constant([[0.0], [1.0], [2.0]])),
        ([3], tf.constant([[0.0], [1.0], [2.0]])),
        (
            [3, 2],
            tf.constant([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 0.0], [2.0, 1.0]]),
        ),
        (["R", "G", "B"], tf.constant([[0.0], [1.0], [2.0]])),
        ([["R", "G", "B"]], tf.constant([[0.0], [1.0], [2.0]])),
        (
            [["R", "G", "B"], ["Y", "N"]],
            tf.constant([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 0.0], [2.0, 1.0]]),
        ),
    ],
)
def test_categorical_search_space__points(
    categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], points: TensorType
) -> None:
    """Check the points enumerated by a categorical space built with ``dtype=tf.float32``.

    :param categories: The category description passed to ``CategoricalSearchSpace``.
    :param points: The tensor that ``space.points`` is expected to equal.
    """
    space = CategoricalSearchSpace(categories, dtype=tf.float32)
    npt.assert_array_equal(space.points, points)

    # Every expected point must be reported as a member, with one result per point.
    membership = space.contains(points)
    assert len(membership) == len(points)
    assert tf.reduce_all(membership)

@pytest.mark.parametrize(
    "categories, exception",
    [
        pytest.param([3, 0, 2], ValueError, id="Empty category size"),
        pytest.param([-2, 1], ValueError, id="Negative category size"),
        pytest.param([["R", "G", "B"], [], ["Y", "N"]], ValueError, id="Empty category list"),
        pytest.param([3, ["A", "B"]], TypeError, id="Mixed description types"),
    ],
)
def test_categorical_search_space__raises(categories: Any, exception: type) -> None:
    """Check that constructing a categorical space from an invalid description raises.

    :param categories: The invalid category description.
    :param exception: The exception type the constructor is expected to raise.
    """
    with pytest.raises(exception):
        _ = CategoricalSearchSpace(categories)


@pytest.mark.parametrize(
    "space1, space2, expected_product",
    [
        pytest.param(
            CategoricalSearchSpace(3), CategoricalSearchSpace(2), CategoricalSearchSpace([3, 2])
        ),
        pytest.param(
            CategoricalSearchSpace([]), CategoricalSearchSpace(2), CategoricalSearchSpace([2])
        ),
        pytest.param(
            CategoricalSearchSpace(["R", "G", "B"]),
            CategoricalSearchSpace(["Y", "N"]),
            CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
        ),
    ],
)
def test_categorical_search_space__product(
    space1: CategoricalSearchSpace,
    space2: CategoricalSearchSpace,
    expected_product: CategoricalSearchSpace,
) -> None:
    """Check that ``space1 * space2`` matches the expected categorical space.

    The result is compared by equality, by its enumerated points and by its tags.

    :param space1: Left operand of the product.
    :param space2: Right operand of the product.
    :param expected_product: The space the product is expected to equal.
    """
    combined = space1 * space2

    assert combined == expected_product
    npt.assert_array_equal(combined.points, expected_product.points)
    assert combined.tags == expected_product.tags


@pytest.mark.parametrize(
    "categories, tags",
    [
        pytest.param([], []),
        pytest.param(3, [("0", "1", "2")]),
        pytest.param([2, 3], [("0", "1"), ("0", "1", "2")]),
        pytest.param(["R", "G", "B"], [("R", "G", "B")]),
        pytest.param([["R", "G", "B"], ["Y", "N"]], [("R", "G", "B"), ("Y", "N")]),
    ],
)
def test_categorical_search_space__tags(
    categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]],
    tags: Sequence[Sequence[str]],
) -> None:
    """Check the ``tags`` exposed by a categorical space.

    :param categories: The category description passed to ``CategoricalSearchSpace``.
    :param tags: The value ``tags`` is expected to compare equal to.
    """
    space = CategoricalSearchSpace(categories)
    assert space.tags == tags


@pytest.mark.parametrize(
    "categories, indices, expected_tags",
    [
        pytest.param(3, tf.constant([[0], [2], [2]]), tf.constant([["0"], ["2"], ["2"]])),
        pytest.param(
            ["A", "B", "C"],
            tf.constant([[0.0], [2.0], [2.0]]),
            tf.constant([["A"], ["C"], ["C"]]),
        ),
        pytest.param(
            (3, 2),
            tf.constant([[0, 1.0], [2.0, 0.0], [2.0, 1.0]]),
            tf.constant([["0", "1"], ["2", "0"], ["2", "1"]]),
        ),
        pytest.param(
            [("A", "B", "C"), ("Y", "N")],
            tf.constant([[0.0, 1.0], [2.0, 0.0], [2.0, 1.0]]),
            tf.constant([["A", "N"], ["C", "Y"], ["C", "N"]]),
        ),
    ],
)
def test_categorical_search_space__to_tags(
    categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]],
    indices: TensorType,
    expected_tags: TensorType,
) -> None:
    """Check that ``to_tags`` maps category indices to the expected tag strings.

    :param categories: The category description passed to ``CategoricalSearchSpace``.
    :param indices: The index tensor handed to ``to_tags``.
    :param expected_tags: The tag tensor ``to_tags`` is expected to return.
    """
    space = CategoricalSearchSpace(categories)
    npt.assert_array_equal(space.to_tags(indices), expected_tags)


def test_categorical_search_space__to_tags_raises_for_non_integers() -> None:
    """Check that ``to_tags`` raises ``ValueError`` when given a non-integral index (1.2)."""
    space = CategoricalSearchSpace(["A", "B", "C"], dtype=tf.float32)
    non_integral_indices = tf.constant([[1.0], [1.2]])
    with pytest.raises(ValueError):
        space.to_tags(non_integral_indices)


@pytest.mark.parametrize(
    "search_space, query_points, encoded_points",
    [
        (
            CategoricalSearchSpace(["V"]),
            tf.constant([[0], [0]]),
            tf.constant([[1], [1]], dtype=tf.float32),
        ),
        (
            CategoricalSearchSpace(["R", "G", "B"]),
            tf.constant([[0], [2], [1]]),
            tf.constant([[1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=tf.float32),
        ),
        (
            CategoricalSearchSpace(["R", "G", "B"]),
            tf.constant([[[[[0]]]]]),
            tf.constant([[[[[1, 0, 0]]]]], dtype=tf.float32),
        ),
        (
            CategoricalSearchSpace(["R", "G", "B", "A"]),
            tf.constant([[0], [2], [2]]),
            tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], dtype=tf.float32),
        ),
        (
            CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
            tf.constant([[0, 0], [2, 0], [1, 1]]),
            tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float32),
        ),
        (
            CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
            tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]]),
            tf.constant(
                [[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]],
                dtype=tf.float32,
            ),
        ),
        (
            TaggedProductSearchSpace([Box([0.0], [1.0]), CategoricalSearchSpace(["R", "G", "B"])]),
            tf.constant([[0.5, 0], [0.3, 2]]),
            tf.constant([[0.5, 1, 0, 0], [0.3, 0, 0, 1]], dtype=tf.float32),
        ),
        (
            TaggedProductSearchSpace([Box([0.0], [1.0]), CategoricalSearchSpace(["R", "G", "B"])]),
            tf.constant([[[0.5, 0]], [[0.3, 2]]]),
            tf.constant([[[0.5, 1, 0, 0]], [[0.3, 0, 0, 1]]], dtype=tf.float32),
        ),
        (
            Box([0.0], [1.0]),
            tf.constant([[0.5], [0.3]]),
            tf.constant([[0.5], [0.3]], dtype=tf.float32),
        ),
    ],
)
def test_categorical_search_space_one_hot_encoding(
    search_space: SearchSpace, query_points: TensorType, encoded_points: TensorType
) -> None:
    """Check that the encoder built by ``one_hot_encoder`` produces the expected encoding.

    The cases cover purely categorical spaces, a product of a ``Box`` with a categorical
    space, inputs with extra leading dimensions, and a plain ``Box`` whose points are
    expected to come back unchanged.

    :param search_space: The space from which the encoder is built.
    :param query_points: The points passed to the encoder.
    :param encoded_points: The expected encoder output.
    """
    encoder = one_hot_encoder(search_space)
    actual = encoder(query_points)
    # Computed value first, expected value second: this is the (actual, desired) order of
    # numpy's signature, so a failure report labels the two arrays correctly.
    npt.assert_array_equal(actual, encoded_points)


@pytest.mark.parametrize(
    "search_space, query_points, exception",
    [
        pytest.param(
            CategoricalSearchSpace(["Y", "N"]),
            tf.constant([0, 2, 1]),
            InvalidArgumentError,
            id="Wrong input rank",
        ),
        pytest.param(
            CategoricalSearchSpace(["Y", "N"]),
            tf.constant([[0], [2], [1]]),
            InvalidArgumentError,
            id="Out of range input value",
        ),
        pytest.param(
            CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
            tf.constant([[0], [1], [1]]),
            ValueError,
            id="Wrong input shape",
        ),
    ],
)
def test_categorical_search_space_one_hot_encoding__raises(
    search_space: CategoricalSearchSpace, query_points: TensorType, exception: type
) -> None:
    """Check that the one-hot encoder rejects malformed query points.

    :param search_space: The space from which the encoder is built.
    :param query_points: The invalid points passed to the encoder.
    :param exception: The exception type the encoder call is expected to raise.
    """
    # Building the encoder happens outside the ``raises`` block: only the call may fail.
    encode = one_hot_encoder(search_space)
    with pytest.raises(exception):
        encode(query_points)


@pytest.mark.parametrize(
    "space",
    [
        CategoricalSearchSpace([3, 2]),
        CategoricalSearchSpace(["R", "G", "B"]),
        TaggedProductSearchSpace([Box([-1, -2], [2, 3]), CategoricalSearchSpace(2)]),
    ],
)
def test_unbound_search_spaces(
    space: SearchSpace,
) -> None:
    """Check that these spaces report no bounds and that reading either bound raises.

    :param space: A search space expected to have ``has_bounds`` falsy and to raise
        ``AttributeError`` on access to ``lower`` and ``upper``.
    """
    assert not space.has_bounds
    for bound_name in ("lower", "upper"):
        with pytest.raises(AttributeError):
            getattr(space, bound_name)
Loading

0 comments on commit 43e6f01

Please sign in to comment.