Skip to content

Commit

Permalink
DiscreteSearchSpaceABC
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 29, 2024
1 parent 2a863b5 commit 5e3cae0
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,22 +374,14 @@ def is_feasible(self, points: TensorType) -> TensorType:
@property
def has_constraints(self) -> bool:
"""Returns `True` if this search space has any explicit constraints specified."""
# By default assume there are no constraints; can be overridden by a subclass.
# By default, assume there are no constraints; can be overridden by a subclass.
return False


class DiscreteSearchSpace(SearchSpace):
r"""
A discrete :class:`SearchSpace` representing a finite set of :math:`D`-dimensional points in
:math:`\mathbb{R}^D`.
For example:
>>> points = tf.constant([[-1.0, 0.4], [-1.0, 0.6], [0.0, 0.4]])
>>> search_space = DiscreteSearchSpace(points)
>>> assert tf.constant([0.0, 0.4]) in search_space
>>> assert tf.constant([1.0, 0.5]) not in search_space
class DiscreteSearchSpaceABC(SearchSpace):
"""
An ABC representing different types of discrete search spaces. This contains
a default implementation using explicitly provided points which subclasses may ignore.
"""

def __init__(self, points: TensorType):
Expand All @@ -402,10 +394,6 @@ def __init__(self, points: TensorType):
self._points = points
self._dimension = tf.shape(self._points)[-1]

def __repr__(self) -> str:
""""""
return f"DiscreteSearchSpace({self._points!r})"

@property
def lower(self) -> TensorType:
"""The lowest value taken across all points by each search space dimension."""
Expand Down Expand Up @@ -448,6 +436,25 @@ def sample(self, num_samples: int, seed: Optional[int] = None) -> TensorType:
)
return tf.gather(self.points, sampled_indices)[0, :, :] # [num_samples, D]


class DiscreteSearchSpace(DiscreteSearchSpaceABC):
r"""
A discrete :class:`SearchSpace` representing a finite set of :math:`D`-dimensional points in
:math:`\mathbb{R}^D`.
For example:
>>> points = tf.constant([[-1.0, 0.4], [-1.0, 0.6], [0.0, 0.4]])
>>> search_space = DiscreteSearchSpace(points)
>>> assert tf.constant([0.0, 0.4]) in search_space
>>> assert tf.constant([1.0, 0.5]) not in search_space
"""

def __repr__(self) -> str:
""""""
return f"DiscreteSearchSpace({self._points!r})"

def product(self, other: DiscreteSearchSpace) -> DiscreteSearchSpace:
r"""
Return the Cartesian product of the two :class:`DiscreteSearchSpace`\ s. For example:
Expand Down Expand Up @@ -485,7 +492,7 @@ def __eq__(self, other: object) -> bool:
return bool(tf.reduce_all(tf.sort(self.points, 0) == tf.sort(other.points, 0)))


class CategoricalSearchSpace(DiscreteSearchSpace):
class CategoricalSearchSpace(DiscreteSearchSpaceABC):
r"""
A categorical :class:`SearchSpace` representing a finite set :math:`\mathcal{C}` of categories,
or a finite Cartesian product :math:`\mathcal{C_1} \times \cdots \times \mathcal{C_n}` of
Expand Down Expand Up @@ -540,6 +547,7 @@ def __init__(self, categories: int | Sequence[int] | Sequence[str] | Sequence[Se
meshgrid = tf.meshgrid(*ranges, indexing="ij")
points = tf.reshape(tf.stack(meshgrid, axis=-1), [-1, len(tags)])

# TODO
super().__init__(points)

def __repr__(self) -> str:
Expand Down Expand Up @@ -592,7 +600,7 @@ def extract_tags(row: TensorType) -> TensorType:

return tf.map_fn(extract_tags, indices, dtype=tf.string)

def product(self, other: DiscreteSearchSpace) -> CategoricalSearchSpace:
def product(self, other: CategoricalSearchSpace) -> CategoricalSearchSpace:
r"""
Return the Cartesian product of the two :class:`CategoricalSearchSpace`\ s. For example:
Expand All @@ -604,9 +612,6 @@ def product(self, other: DiscreteSearchSpace) -> CategoricalSearchSpace:
:param other: A :class:`CategoricalSearchSpace`.
:return: The Cartesian product of the two :class:`CategoricalSearchSpace`\ s.
"""
if not isinstance(other, CategoricalSearchSpace):
return NotImplemented

return CategoricalSearchSpace(tuple(chain(self.tags, other.tags)))

def __eq__(self, other: object) -> bool:
Expand Down

0 comments on commit 5e3cae0

Please sign in to comment.