From 5e3cae0422790504074ca0940a835b9b354ce7a5 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Mon, 29 Jul 2024 11:52:20 +0100 Subject: [PATCH] DiscreteSearchSpaceABC --- trieste/space.py | 49 ++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/trieste/space.py b/trieste/space.py index ec8be7e96..564e062f9 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -374,22 +374,14 @@ def is_feasible(self, points: TensorType) -> TensorType: @property def has_constraints(self) -> bool: """Returns `True` if this search space has any explicit constraints specified.""" - # By default assume there are no constraints; can be overridden by a subclass. + # By default, assume there are no constraints; can be overridden by a subclass. return False -class DiscreteSearchSpace(SearchSpace): - r""" - A discrete :class:`SearchSpace` representing a finite set of :math:`D`-dimensional points in - :math:`\mathbb{R}^D`. - - For example: - - >>> points = tf.constant([[-1.0, 0.4], [-1.0, 0.6], [0.0, 0.4]]) - >>> search_space = DiscreteSearchSpace(points) - >>> assert tf.constant([0.0, 0.4]) in search_space - >>> assert tf.constant([1.0, 0.5]) not in search_space - +class DiscreteSearchSpaceABC(SearchSpace): + """ + An ABC representing different types of discrete search spaces. This contains + a default implementation using explicitly provided points which subclasses may ignore. """ def __init__(self, points: TensorType): @@ -402,10 +394,6 @@ def __init__(self, points: TensorType): self._points = points self._dimension = tf.shape(self._points)[-1] - def __repr__(self) -> str: - """""" - return f"DiscreteSearchSpace({self._points!r})" - @property def lower(self) -> TensorType: """The lowest value taken across all points by each search space dimension.""" @@ -448,6 +436,25 @@ def sample(self, num_samples: int, seed: Optional[int] = None) -> TensorType: ) return tf.gather(self.points, sampled_indices)[0, :, :] # [num_samples, D] + +class DiscreteSearchSpace(DiscreteSearchSpaceABC): + r""" + A discrete :class:`SearchSpace` representing a finite set of :math:`D`-dimensional points in + :math:`\mathbb{R}^D`. + + For example: + + >>> points = tf.constant([[-1.0, 0.4], [-1.0, 0.6], [0.0, 0.4]]) + >>> search_space = DiscreteSearchSpace(points) + >>> assert tf.constant([0.0, 0.4]) in search_space + >>> assert tf.constant([1.0, 0.5]) not in search_space + + """ + + def __repr__(self) -> str: + """""" + return f"DiscreteSearchSpace({self._points!r})" + def product(self, other: DiscreteSearchSpace) -> DiscreteSearchSpace: r""" Return the Cartesian product of the two :class:`DiscreteSearchSpace`\ s. For example: @@ -485,7 +492,7 @@ def __eq__(self, other: object) -> bool: return bool(tf.reduce_all(tf.sort(self.points, 0) == tf.sort(other.points, 0))) -class CategoricalSearchSpace(DiscreteSearchSpace): +class CategoricalSearchSpace(DiscreteSearchSpaceABC): r""" A categorical :class:`SearchSpace` representing a finite set :math:`\mathcal{C}` of categories, or a finite Cartesian product :math:`\mathcal{C_1} \times \cdots \times \mathcal{C_n}` of @@ -540,6 +547,7 @@ def __init__(self, categories: int | Sequence[int] | Sequence[str] | Sequence[Se meshgrid = tf.meshgrid(*ranges, indexing="ij") points = tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)]) + # TODO super().__init__(points) def __repr__(self) -> str: @@ -592,7 +600,7 @@ def extract_tags(row: TensorType) -> TensorType: return tf.map_fn(extract_tags, indices, dtype=tf.string) - def product(self, other: DiscreteSearchSpace) -> CategoricalSearchSpace: + def product(self, other: CategoricalSearchSpace) -> CategoricalSearchSpace: r""" Return the Cartesian product of the two :class:`CategoricalSearchSpace`\ s. For example: @@ -604,9 +612,6 @@ def product(self, other: DiscreteSearchSpace) -> CategoricalSearchSpace: :param other: A :class:`CategoricalSearchSpace`. :return: The Cartesian product of the two :class:`CategoricalSearchSpace`\ s. """ - if not isinstance(other, CategoricalSearchSpace): - return NotImplemented - return CategoricalSearchSpace(tuple(chain(self.tags, other.tags))) def __eq__(self, other: object) -> bool: