diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index 3800e6436..b2caae161 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -22,19 +22,23 @@ import numpy.testing as npt import pytest import tensorflow as tf +from tensorflow.python.framework.errors_impl import InvalidArgumentError from typing_extensions import Final from tests.util.misc import TF_DEBUGGING_ERROR_TYPES, ShapeLike, various_shapes from trieste.space import ( Box, + CategoricalSearchSpace, CollectionSearchSpace, Constraint, DiscreteSearchSpace, + GeneralDiscreteSearchSpace, LinearConstraint, NonlinearConstraint, SearchSpace, TaggedMultiSearchSpace, TaggedProductSearchSpace, + one_hot_encoder, ) from trieste.types import TensorType @@ -44,6 +48,10 @@ def __init__(self, exclusive_limit: int): assert exclusive_limit > 0 self.limit: Final[int] = exclusive_limit + @property + def has_bounds(self) -> bool: + return True + @property def lower(self) -> None: pass @@ -183,6 +191,7 @@ def test_discrete_search_space_returns_correct_dimension( def test_discrete_search_space_returns_correct_bounds( space: DiscreteSearchSpace, lower: tf.Tensor, upper: tf.Tensor ) -> None: + assert space.has_bounds npt.assert_array_equal(space.lower, lower) npt.assert_array_equal(space.upper, upper) @@ -207,22 +216,47 @@ def test_discrete_search_space_contains_raises_for_invalid_shapes( @pytest.mark.parametrize("num_samples", [0, 1, 3, 5, 6, 10, 20]) -def test_discrete_search_space_sampling(num_samples: int) -> None: - search_space = DiscreteSearchSpace(_points_in_2D_search_space()) +@pytest.mark.parametrize( + "search_space", + [ + pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"), + pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"), + ], +) +def test_discrete_search_space_sampling( + search_space: GeneralDiscreteSearchSpace, num_samples: int +) -> None: samples = search_space.sample(num_samples) assert all(sample in search_space 
for sample in samples) assert len(samples) == num_samples @pytest.mark.parametrize("seed", [1, 42, 123]) -def test_discrete_search_space_sampling_returns_same_points_for_same_seed(seed: int) -> None: - search_space = DiscreteSearchSpace(_points_in_2D_search_space()) +@pytest.mark.parametrize( + "search_space", + [ + pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"), + pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"), + ], +) +def test_discrete_search_space_sampling_returns_same_points_for_same_seed( + search_space: GeneralDiscreteSearchSpace, seed: int +) -> None: random_samples_1 = search_space.sample(num_samples=100, seed=seed) random_samples_2 = search_space.sample(num_samples=100, seed=seed) npt.assert_allclose(random_samples_1, random_samples_2) -def test_discrete_search_space_sampling_returns_different_points_for_different_call() -> None: +@pytest.mark.parametrize( + "search_space", + [ + pytest.param(DiscreteSearchSpace(_points_in_2D_search_space()), id="DiscreteSearchSpace"), + pytest.param(CategoricalSearchSpace([3, 2]), id="CategoricalSearchSpace"), + ], +) +def test_discrete_search_space_sampling_returns_different_points_for_different_call( + search_space: GeneralDiscreteSearchSpace, +) -> None: search_space = DiscreteSearchSpace(_points_in_2D_search_space()) random_samples_1 = search_space.sample(num_samples=100) random_samples_2 = search_space.sample(num_samples=100) @@ -380,6 +414,7 @@ def test_box_returns_correct_dimension(space: Box, dimension: int) -> None: def test_box_bounds_attributes() -> None: lower, upper = tf.zeros([2]), tf.ones([2]) box = Box(lower, upper) + assert box.has_bounds npt.assert_array_equal(box.lower, lower) npt.assert_array_equal(box.upper, upper) @@ -898,6 +933,7 @@ def test_collection_space_contains_raises_on_point_of_different_shape( [ (TaggedMultiSearchSpace, Box([-1, -2], [2, 3]), [2, 2]), (TaggedProductSearchSpace, Box([-1], [2]), [3]), + 
(TaggedProductSearchSpace, CategoricalSearchSpace(["A", "B", "C"]), [3]), ], ) @pytest.mark.parametrize("num_samples", [0, 1, 10]) @@ -1081,6 +1117,7 @@ def test_product_space_returns_correct_bounds( spaces: Sequence[SearchSpace], lower: tf.Tensor, upper: tf.Tensor ) -> None: for space in (TaggedProductSearchSpace(spaces=spaces), reduce(operator.mul, spaces)): + assert space.has_bounds npt.assert_array_equal(space.lower, lower) npt.assert_array_equal(space.upper, upper) @@ -1150,10 +1187,10 @@ def test_product_space_fix_subspace_doesnt_fix_undesired_subspace(points: tf.Ten [ Box([-1, -2], [2, 3]), DiscreteSearchSpace(tf.constant([[-0.5]])), - Box([-1], [2]), + CategoricalSearchSpace([3, 2]), ], ["A", "B", "C"], - {"A": [0, 2], "B": [2, 3], "C": [3, 4]}, + {"A": [0, 2], "B": [2, 3], "C": [3, 5]}, ), ], ) @@ -1438,6 +1475,11 @@ def _nlc_func(x: TensorType) -> TensorType: ), False, ), + (CategoricalSearchSpace([3, 2]), CategoricalSearchSpace([3, 2]), True), + (CategoricalSearchSpace([3]), CategoricalSearchSpace([3, 2]), False), + (CategoricalSearchSpace(3), CategoricalSearchSpace(["0", "1", "2"]), True), + (CategoricalSearchSpace(3), CategoricalSearchSpace(["R", "G", "B"]), False), + (CategoricalSearchSpace(3), DiscreteSearchSpace(tf.constant([[0], [1]])), False), ], ) def test___eq___search_spaces(a: SearchSpace, b: SearchSpace, equal: bool) -> None: @@ -1591,3 +1633,229 @@ def test_box_empty_halton_sampling_returns_correct_dtype(dtype: tf.DType) -> Non box = Box(tf.zeros((3,), dtype=dtype), tf.ones((3,), dtype=dtype)) sobol_samples = box.sample_halton(0) assert sobol_samples.dtype == dtype + + +@pytest.mark.parametrize( + "categories, points", + [ + pytest.param([], tf.zeros([0, 0])), + pytest.param(3, tf.constant([[0.0], [1.0], [2.0]])), + pytest.param([3], tf.constant([[0.0], [1.0], [2.0]])), + pytest.param( + [3, 2], + tf.constant([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 0.0], [2.0, 1.0]]), + ), + pytest.param(["R", "G", "B"], tf.constant([[0.0], 
[1.0], [2.0]])), + pytest.param([["R", "G", "B"]], tf.constant([[0.0], [1.0], [2.0]])), + pytest.param( + [["R", "G", "B"], ["Y", "N"]], + tf.constant([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 0.0], [2.0, 1.0]]), + ), + ], +) +def test_categorical_search_space__points( + categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], points: TensorType +) -> None: + space = CategoricalSearchSpace(categories, dtype=tf.float32) + npt.assert_array_equal(space.points, points) + contains = space.contains(points) + assert len(contains) == len(points) + assert tf.reduce_all(contains) + + +@pytest.mark.parametrize( + "categories, exception", + [ + pytest.param([3, 0, 2], ValueError, id="Empty category size"), + pytest.param([-2, 1], ValueError, id="Negative category size"), + pytest.param([["R", "G", "B"], [], ["Y", "N"]], ValueError, id="Empty category list"), + pytest.param([3, ["A", "B"]], TypeError, id="Mixed description types"), + ], +) +def test_categorical_search_space__raises(categories: Any, exception: type) -> None: + with pytest.raises(exception): + CategoricalSearchSpace(categories) + + +@pytest.mark.parametrize( + "space1, space2, expected_product", + [ + (CategoricalSearchSpace(3), CategoricalSearchSpace(2), CategoricalSearchSpace([3, 2])), + (CategoricalSearchSpace([]), CategoricalSearchSpace(2), CategoricalSearchSpace([2])), + ( + CategoricalSearchSpace(["R", "G", "B"]), + CategoricalSearchSpace(["Y", "N"]), + CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), + ), + ], +) +def test_categorical_search_space__product( + space1: CategoricalSearchSpace, + space2: CategoricalSearchSpace, + expected_product: CategoricalSearchSpace, +) -> None: + product = space1 * space2 + assert product == expected_product + npt.assert_array_equal(product.points, expected_product.points) + assert product.tags == expected_product.tags + + +@pytest.mark.parametrize( + "categories, tags", + [ + ([], []), + (3, [("0", "1", "2")]), + ([2, 3], [("0", "1"), 
("0", "1", "2")]), + (["R", "G", "B"], [("R", "G", "B")]), + ([["R", "G", "B"], ["Y", "N"]], [("R", "G", "B"), ("Y", "N")]), + ], +) +def test_categorical_search_space__tags( + categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], + tags: Sequence[Sequence[str]], +) -> None: + search_space = CategoricalSearchSpace(categories) + assert search_space.tags == tags + + +@pytest.mark.parametrize( + "categories, indices, expected_tags", + [ + (3, tf.constant([[0], [2], [2]]), tf.constant([["0"], ["2"], ["2"]])), + (["A", "B", "C"], tf.constant([[0.0], [2.0], [2.0]]), tf.constant([["A"], ["C"], ["C"]])), + ( + (3, 2), + tf.constant([[0, 1.0], [2.0, 0.0], [2.0, 1.0]]), + tf.constant([["0", "1"], ["2", "0"], ["2", "1"]]), + ), + ( + [("A", "B", "C"), ("Y", "N")], + tf.constant([[0.0, 1.0], [2.0, 0.0], [2.0, 1.0]]), + tf.constant([["A", "N"], ["C", "Y"], ["C", "N"]]), + ), + ], +) +def test_categorical_search_space__to_tags( + categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], + indices: TensorType, + expected_tags: TensorType, +) -> None: + search_space = CategoricalSearchSpace(categories) + tags = search_space.to_tags(indices) + npt.assert_array_equal(tags, expected_tags) + + +def test_categorical_search_space__to_tags_raises_for_non_integers() -> None: + search_space = CategoricalSearchSpace(["A", "B", "C"], dtype=tf.float32) + with pytest.raises(ValueError): + search_space.to_tags(tf.constant([[1.0], [1.2]])) + + +@pytest.mark.parametrize( + "search_space, query_points, encoded_points", + [ + ( + CategoricalSearchSpace(["V"]), + tf.constant([[0], [0]]), + tf.constant([[1], [1]], dtype=tf.float32), + ), + ( + CategoricalSearchSpace(["R", "G", "B"]), + tf.constant([[0], [2], [1]]), + tf.constant([[1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=tf.float32), + ), + ( + CategoricalSearchSpace(["R", "G", "B"]), + tf.constant([[[[[0]]]]]), + tf.constant([[[[[1, 0, 0]]]]], dtype=tf.float32), + ), + ( + CategoricalSearchSpace(["R", "G", "B", 
"A"]), + tf.constant([[0], [2], [2]]), + tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], dtype=tf.float32), + ), + ( + CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), + tf.constant([[0, 0], [2, 0], [1, 1]]), + tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float32), + ), + ( + CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), + tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]]), + tf.constant( + [[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]], + dtype=tf.float32, + ), + ), + ( + TaggedProductSearchSpace([Box([0.0], [1.0]), CategoricalSearchSpace(["R", "G", "B"])]), + tf.constant([[0.5, 0], [0.3, 2]]), + tf.constant([[0.5, 1, 0, 0], [0.3, 0, 0, 1]], dtype=tf.float32), + ), + ( + TaggedProductSearchSpace([Box([0.0], [1.0]), CategoricalSearchSpace(["R", "G", "B"])]), + tf.constant([[[0.5, 0]], [[0.3, 2]]]), + tf.constant([[[0.5, 1, 0, 0]], [[0.3, 0, 0, 1]]], dtype=tf.float32), + ), + ( + Box([0.0], [1.0]), + tf.constant([[0.5], [0.3]]), + tf.constant([[0.5], [0.3]], dtype=tf.float32), + ), + ], +) +def test_categorical_search_space_one_hot_encoding( + search_space: SearchSpace, query_points: TensorType, encoded_points: TensorType +) -> None: + encoder = one_hot_encoder(search_space) + points = encoder(query_points) + npt.assert_array_equal(encoded_points, points) + + +@pytest.mark.parametrize( + "search_space, query_points, exception", + [ + pytest.param( + CategoricalSearchSpace(["Y", "N"]), + tf.constant([0, 2, 1]), + InvalidArgumentError, + id="Wrong input rank", + ), + pytest.param( + CategoricalSearchSpace(["Y", "N"]), + tf.constant([[0], [2], [1]]), + InvalidArgumentError, + id="Out of range input value", + ), + pytest.param( + CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), + tf.constant([[0], [1], [1]]), + ValueError, + id="Wrong input shape", + ), + ], +) +def test_categorical_search_space_one_hot_encoding__raises( + search_space: CategoricalSearchSpace, query_points: TensorType, 
exception: type +) -> None: + encoder = one_hot_encoder(search_space) + with pytest.raises(exception): + encoder(query_points) + + +@pytest.mark.parametrize( + "space", + [ + CategoricalSearchSpace([3, 2]), + CategoricalSearchSpace(["R", "G", "B"]), + TaggedProductSearchSpace([Box([-1, -2], [2, 3]), CategoricalSearchSpace(2)]), + ], +) +def test_unbound_search_spaces( + space: SearchSpace, +) -> None: + assert not space.has_bounds + with pytest.raises(AttributeError): + space.lower + with pytest.raises(AttributeError): + space.upper diff --git a/trieste/space.py b/trieste/space.py index 1c361dc6e..326cc052b 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -17,6 +17,7 @@ import operator from abc import ABC, abstractmethod from functools import reduce +from itertools import chain from typing import Callable, Optional, Sequence, Tuple, TypeVar, Union, overload import numpy as np @@ -24,8 +25,10 @@ import tensorflow as tf import tensorflow_probability as tfp from check_shapes import check_shapes +from typing_extensions import Protocol, runtime_checkable from .types import TensorType +from .utils import flatten_leading_dims SearchSpaceType = TypeVar("SearchSpaceType", bound="SearchSpace") """ A type variable bound to :class:`SearchSpace`. """ @@ -33,6 +36,9 @@ DEFAULT_DTYPE: tf.DType = tf.float64 """ Default dtype to use when none is provided. """ +EncoderFunction = Callable[[TensorType], TensorType] +""" Type alias for point encoders. These transform points from one search space to another. 
""" + class SampleTimeoutError(Exception): """Raised when sampling from a search space has timed out.""" @@ -261,6 +267,11 @@ def __contains__(self, value: TensorType) -> bool: def dimension(self) -> TensorType: """The number of inputs in this search space.""" + @property + @abstractmethod + def has_bounds(self) -> bool: + """Whether the search space has meaningful numerical bounds.""" + @property @abstractmethod def lower(self) -> TensorType: @@ -369,22 +380,15 @@ def is_feasible(self, points: TensorType) -> TensorType: @property def has_constraints(self) -> bool: """Returns `True` if this search space has any explicit constraints specified.""" - # By default assume there are no constraints; can be overridden by a subclass. + # By default, assume there are no constraints; can be overridden by a subclass. return False -class DiscreteSearchSpace(SearchSpace): - r""" - A discrete :class:`SearchSpace` representing a finite set of :math:`D`-dimensional points in - :math:`\mathbb{R}^D`. - - For example: - - >>> points = tf.constant([[-1.0, 0.4], [-1.0, 0.6], [0.0, 0.4]]) - >>> search_space = DiscreteSearchSpace(points) - >>> assert tf.constant([0.0, 0.4]) in search_space - >>> assert tf.constant([1.0, 0.5]) not in search_space - +class GeneralDiscreteSearchSpace(SearchSpace): + """ + An ABC representing different types of discrete search spaces (not just numerical). + This contains a default implementation using explicitly provided points which subclasses + may ignore. 
""" def __init__(self, points: TensorType): @@ -397,20 +401,6 @@ def __init__(self, points: TensorType): self._points = points self._dimension = tf.shape(self._points)[-1] - def __repr__(self) -> str: - """""" - return f"DiscreteSearchSpace({self._points!r})" - - @property - def lower(self) -> TensorType: - """The lowest value taken across all points by each search space dimension.""" - return tf.reduce_min(self.points, -2) - - @property - def upper(self) -> TensorType: - """The highest value taken across all points by each search space dimension.""" - return tf.reduce_max(self.points, -2) - @property def points(self) -> TensorType: """All the points in this space.""" @@ -422,7 +412,7 @@ def dimension(self) -> TensorType: return self._dimension def _contains(self, value: TensorType) -> TensorType: - comparison = tf.math.equal(self._points, tf.expand_dims(value, -2)) # [..., N, D] + comparison = tf.math.equal(self.points, tf.expand_dims(value, -2)) # [..., N, D] return tf.reduce_any(tf.reduce_all(comparison, axis=-1), axis=-1) # [...] def sample(self, num_samples: int, seed: Optional[int] = None) -> TensorType: @@ -443,6 +433,39 @@ def sample(self, num_samples: int, seed: Optional[int] = None) -> TensorType: ) return tf.gather(self.points, sampled_indices)[0, :, :] # [num_samples, D] + +class DiscreteSearchSpace(GeneralDiscreteSearchSpace): + r""" + A discrete :class:`SearchSpace` representing a finite set of :math:`D`-dimensional points in + :math:`\mathbb{R}^D`. 
+ + For example: + + >>> points = tf.constant([[-1.0, 0.4], [-1.0, 0.6], [0.0, 0.4]]) + >>> search_space = DiscreteSearchSpace(points) + >>> assert tf.constant([0.0, 0.4]) in search_space + >>> assert tf.constant([1.0, 0.5]) not in search_space + + """ + + def __repr__(self) -> str: + """""" + return f"DiscreteSearchSpace({self._points!r})" + + @property + def has_bounds(self) -> bool: + return True + + @property + def lower(self) -> TensorType: + """The lowest value taken across all points by each search space dimension.""" + return tf.reduce_min(self.points, -2) + + @property + def upper(self) -> TensorType: + """The highest value taken across all points by each search space dimension.""" + return tf.reduce_max(self.points, -2) + def product(self, other: DiscreteSearchSpace) -> DiscreteSearchSpace: r""" Return the Cartesian product of the two :class:`DiscreteSearchSpace`\ s. For example: @@ -480,6 +503,186 @@ def __eq__(self, other: object) -> bool: return bool(tf.reduce_all(tf.sort(self.points, 0) == tf.sort(other.points, 0))) +@runtime_checkable +class HasOneHotEncoder(Protocol): + """A categorical search space that contains default logic for one-hot encoding.""" + + @property + @abstractmethod + def one_hot_encoder(self) -> EncoderFunction: + "A one-hot encoder for points in the search space." + + +def one_hot_encoder(space: SearchSpace) -> EncoderFunction: + "A utility function for one-hot encoding a search space when it supports it." + return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x + + +class CategoricalSearchSpace(GeneralDiscreteSearchSpace, HasOneHotEncoder): + r""" + A categorical :class:`SearchSpace` representing a finite set :math:`\mathcal{C}` of categories, + or a finite Cartesian product :math:`\mathcal{C}_1 \times \cdots \times \mathcal{C}_n` of + such sets. 
+ + For example: + + >>> CategoricalSearchSpace(5) + CategoricalSearchSpace([('0', '1', '2', '3', '4')]) + >>> CategoricalSearchSpace(["Red", "Green", "Blue"]) + CategoricalSearchSpace([('Red', 'Green', 'Blue')]) + >>> CategoricalSearchSpace([2,3]) + CategoricalSearchSpace([('0', '1'), ('0', '1', '2')]) + >>> CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]) + CategoricalSearchSpace([('R', 'G', 'B'), ('Y', 'N')]) + + Note that internally categories are represented by numeric indices: + + >>> rgb = CategoricalSearchSpace(["Red", "Green", "Blue"]) + >>> assert tf.constant([1], dtype=tf.float64) in rgb + >>> assert tf.constant([3], dtype=tf.float64) not in rgb + >>> rgb.to_tags(tf.constant([[1], [0], [2]])) + <tf.Tensor: shape=(3, 1), dtype=string, numpy= + array([[b'Green'], + [b'Red'], + [b'Blue']], dtype=object)> + + """ + + def __init__( + self, + categories: int | Sequence[int] | Sequence[str] | Sequence[Sequence[str]], + dtype: tf.DType = DEFAULT_DTYPE, + ): + """ + :param categories: Number of categories or category names. Can be an array for + multidimensional spaces. + :param dtype: The dtype of the returned indices, either tf.float32 or tf.float64. + """ + if isinstance(categories, int) or any(isinstance(x, str) for x in categories): + categories = [categories] # type: ignore[assignment] + + if not isinstance(categories, Sequence) or not ( + all( + isinstance(x, Sequence) + and not isinstance(x, str) + and all(isinstance(y, str) for y in x) + for x in categories + ) + or all(isinstance(x, int) for x in categories) + ): + raise TypeError( + f"Invalid category description {categories!r}: " "expected either numbers or names." 
+ ) + + elif any(isinstance(x, int) for x in categories): + category_lens: Sequence[int] = categories # type: ignore[assignment] + if any(x <= 0 for x in category_lens): + raise ValueError("Numbers of categories must be positive") + tags = [tuple(f"{i}" for i in range(n)) for n in category_lens] + else: + category_names: Sequence[Sequence[str]] = categories # type: ignore[assignment] + if any(len(ts) == 0 for ts in category_names): + raise ValueError("Category name lists cannot be empty") + tags = [tuple(ts) for ts in category_names] + + self._tags = tags + + ranges = [tf.range(len(ts), dtype=dtype) for ts in tags] + meshgrid = tf.meshgrid(*ranges, indexing="ij") + points = ( + tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)]) if tags else tf.zeros([0, 0]) + ) + + super().__init__(points) + + def __repr__(self) -> str: + """""" + return f"CategoricalSearchSpace({self._tags!r})" + + @property + def has_bounds(self) -> bool: + return False + + @property + def lower(self) -> TensorType: + raise AttributeError("Categorical search spaces do not have numerical bounds") + + @property + def upper(self) -> TensorType: + raise AttributeError("Categorical search spaces do not have numerical bounds") + + @property + def tags(self) -> Sequence[Sequence[str]]: + """The tags of the categories.""" + return self._tags + + @property + def one_hot_encoder(self) -> EncoderFunction: + """A one-hot encoder for the numerical indices.""" + + def encoder(x: TensorType) -> TensorType: + flat_x, unflatten = flatten_leading_dims(x) + if flat_x.shape[-1] != len(self.tags): + raise ValueError( + "Invalid input for one-hot encoding: " + f"expected {len(self.tags)} tags, got {flat_x.shape[-1]}" + ) + columns = tf.split(flat_x, flat_x.shape[-1], axis=1) + encoders = [ + tf.keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") + for ts in self.tags + ] + encoded = tf.concat( + [encoder(column) for encoder, column in zip(encoders, columns)], axis=1 + ) + return 
unflatten(encoded) + + return encoder + + def to_tags(self, indices: TensorType) -> TensorType: + """ + Convert a tensor of indices (such as one returned by :meth:`sample`) to one of + category tags. + + :param indices: A tensor of integer indices. + :return: A tensor of string tags. + """ + if indices.dtype.is_floating: + if not tf.reduce_all(tf.math.equal(indices, tf.math.floor(indices))): + raise ValueError("Non-integral indices passed to to_tags") + indices = tf.cast(indices, dtype=tf.int32) + + def extract_tags(row: TensorType) -> TensorType: + return tf.stack( + [tf.gather(tf.constant(self._tags[i]), row[i]) for i in range(len(row))] + ) + + return tf.map_fn(extract_tags, indices, dtype=tf.string) + + def product(self, other: CategoricalSearchSpace) -> CategoricalSearchSpace: + r""" + Return the Cartesian product of the two :class:`CategoricalSearchSpace`\ s. For example: + + >>> rgb = CategoricalSearchSpace(["Red", "Green", "Blue"]) + >>> yn = CategoricalSearchSpace(["Yes", "No"]) + >>> rgb * yn + CategoricalSearchSpace([('Red', 'Green', 'Blue'), ('Yes', 'No')]) + + :param other: A :class:`CategoricalSearchSpace`. + :return: The Cartesian product of the two :class:`CategoricalSearchSpace`\ s. + """ + return CategoricalSearchSpace(tuple(chain(self.tags, other.tags))) + + def __eq__(self, other: object) -> bool: + """ + :param other: A search space. + :return: Whether the search space is identical to this one. 
+ """ + if not isinstance(other, CategoricalSearchSpace): + return NotImplemented + return self.tags == other.tags + + class Box(SearchSpace): r""" Continuous :class:`SearchSpace` representing a :math:`D`-dimensional box in @@ -559,6 +762,10 @@ def __repr__(self) -> str: """""" return f"Box({self._lower!r}, {self._upper!r}, {self._constraints!r}, {self._ctol!r})" + @property + def has_bounds(self) -> bool: + return True + @property def lower(self) -> tf.Tensor: """The lower bounds of the box.""" @@ -921,6 +1128,11 @@ def __repr__(self) -> str: tags = {self.subspace_tags}) """ + @property + def has_bounds(self) -> bool: + """Whether the search space has meaningful numerical bounds.""" + return all(self.get_subspace(tag).has_bounds for tag in self.subspace_tags) + @property def subspace_lower(self) -> Sequence[TensorType]: """The lowest values taken by each space dimension, in the same order as specified when @@ -988,7 +1200,7 @@ def __eq__(self, other: object) -> bool: return self._tags == other._tags and self._spaces == other._spaces -class TaggedProductSearchSpace(CollectionSearchSpace): +class TaggedProductSearchSpace(CollectionSearchSpace, HasOneHotEncoder): r""" Product :class:`SearchSpace` consisting of a product of multiple :class:`SearchSpace`. 
This class provides functionality for @@ -1136,6 +1348,23 @@ def product(self, other: TaggedProductSearchSpace) -> TaggedProductSearchSpace: """ return TaggedProductSearchSpace(spaces=[self, other]) + @property + def one_hot_encoder(self) -> EncoderFunction: + """An encoder that one-hot-encodes all subspaces that support it (and leaves + the other subspaces unchanged).""" + + def encoder(x: TensorType) -> TensorType: + components = [] + for tag in self.subspace_tags: + component = self.get_subspace_component(tag, x) + space = self.get_subspace(tag) + if isinstance(space, HasOneHotEncoder): + component = space.one_hot_encoder(component) + components.append(component) + return tf.concat(components, axis=-1) + + return encoder + class TaggedMultiSearchSpace(CollectionSearchSpace): r"""