Skip to content

Commit

Permalink
Explore adding event_shape property to ProbabilisticModel
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jan 26, 2024
1 parent 5528236 commit 5855fd7
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 1 deletion.
4 changes: 4 additions & 0 deletions tests/unit/test_bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,10 @@ def test_bayesian_optimizer_optimize_raises_for_negative_steps(num_steps: int) -

def test_bayesian_optimizer_optimize_is_noop_for_zero_steps() -> None:
class _UnusableModel(TrainableProbabilisticModel):
@property
def event_shape(self) -> TensorType:
assert False

def predict(self, query_points: TensorType) -> NoReturn:
assert False

Expand Down
10 changes: 9 additions & 1 deletion tests/util/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,22 @@ def __init__(
mean_functions: Sequence[Callable[[TensorType], TensorType]],
kernels: Sequence[tfp.math.psd_kernels.PositiveSemidefiniteKernel],
noise_variance: float = 1.0,
event_shape: Optional[TensorType] = None,
):
self._mean_functions = mean_functions
self._kernels = kernels
self._noise_variance = noise_variance
self._event_shape = event_shape

def __repr__(self) -> str:
return f"GaussianProcess({self._mean_functions!r}, {self._kernels!r})"

@property
def event_shape(self) -> TensorType:
if self._event_shape is None:
raise NotImplementedError("event_shape not specified for this model")
return self._event_shape

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = self.predict_joint(query_points[..., None, :])
Expand Down Expand Up @@ -198,7 +206,7 @@ def __init__(
):
self.kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(kernel_amplitude)
self.mean_function = lambda x: quadratic(x - tf.cast(x_shift, dtype=x.dtype))
super().__init__([self.mean_function], [self.kernel], noise_variance)
super().__init__([self.mean_function], [self.kernel], noise_variance, tf.constant([1]))

def __repr__(self) -> str:
return "QuadraticMeanAndRBFKernel()"
Expand Down
4 changes: 4 additions & 0 deletions trieste/models/gpflow/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def update_posterior_cache(self) -> None:
def model(self) -> GPModel:
"""The underlying GPflow model."""

@property
def event_shape(self) -> TensorType:
return tf.constant([self.model.num_latent_gps])

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = (self._posterior or self.model).predict_f(query_points)
Expand Down
4 changes: 4 additions & 0 deletions trieste/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,10 @@ def __init__(
def num_fidelities(self) -> int:
return self._num_fidelities

@property
def event_shape(self) -> TensorType:
return self.lowest_fidelity_signal_model.event_shape

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""
Expand Down
4 changes: 4 additions & 0 deletions trieste/models/gpflux/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ def model_gpflux(self) -> DeepGP:
def model_keras(self) -> tf.keras.Model:
return self._model_keras

@property
def event_shape(self) -> TensorType:
raise NotImplementedError # TODO: ???

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
trajectory = self.trajectory_sampler().get_trajectory()
Expand Down
12 changes: 12 additions & 0 deletions trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ class ProbabilisticModel(Protocol):
https://github.com/python/typing/issues/213 is implemented.
"""

@property
@abstractmethod
def event_shape(self) -> TensorType:
"""
The event shape of the model's predictive distribution.
"""

@abstractmethod
@check_shapes(
"query_points: [batch..., D]",
Expand Down Expand Up @@ -367,6 +374,11 @@ def __init__(
# NB we don't use @inherit_shapes below as some classes break the shape API (👀 fantasizer)
# instead we rely on the shape checking inside the submodels

@property
def event_shape(self) -> TensorType:
shapes = [model.event_shape for model in self._models]
return tf.concat([shapes[0][:-1], [sum(shape[-1] for shape in shapes)]], axis=0)

def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
r"""
:param query_points: The points at which to make predictions, of shape [..., D].
Expand Down
4 changes: 4 additions & 0 deletions trieste/models/keras/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def model(self) -> tf.keras.Model:
"""The compiled Keras model."""
raise NotImplementedError

@property
def event_shape(self) -> TensorType:
return tf.constant([self.model.output_shape[0][-1]])

@property
def optimizer(self) -> KerasOptimizer:
"""The optimizer wrapper for training the model."""
Expand Down

0 comments on commit 5855fd7

Please sign in to comment.