diff --git a/tests/unit/test_bayesian_optimizer.py b/tests/unit/test_bayesian_optimizer.py index f6d2596544..cec30bf924 100644 --- a/tests/unit/test_bayesian_optimizer.py +++ b/tests/unit/test_bayesian_optimizer.py @@ -517,6 +517,10 @@ def test_bayesian_optimizer_optimize_raises_for_negative_steps(num_steps: int) - def test_bayesian_optimizer_optimize_is_noop_for_zero_steps() -> None: class _UnusableModel(TrainableProbabilisticModel): + @property + def event_shape(self) -> TensorType: + assert False + def predict(self, query_points: TensorType) -> NoReturn: assert False diff --git a/tests/util/models/gpflow/models.py b/tests/util/models/gpflow/models.py index 650fc65f08..c7ea8bf009 100644 --- a/tests/util/models/gpflow/models.py +++ b/tests/util/models/gpflow/models.py @@ -91,14 +91,22 @@ def __init__( mean_functions: Sequence[Callable[[TensorType], TensorType]], kernels: Sequence[tfp.math.psd_kernels.PositiveSemidefiniteKernel], noise_variance: float = 1.0, + event_shape: Optional[TensorType] = None, ): self._mean_functions = mean_functions self._kernels = kernels self._noise_variance = noise_variance + self._event_shape = event_shape def __repr__(self) -> str: return f"GaussianProcess({self._mean_functions!r}, {self._kernels!r})" + @property + def event_shape(self) -> TensorType: + if self._event_shape is None: + raise NotImplementedError("event_shape not specified for this model") + return self._event_shape + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, cov = self.predict_joint(query_points[..., None, :]) @@ -198,7 +206,7 @@ def __init__( ): self.kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(kernel_amplitude) self.mean_function = lambda x: quadratic(x - tf.cast(x_shift, dtype=x.dtype)) - super().__init__([self.mean_function], [self.kernel], noise_variance) + super().__init__([self.mean_function], [self.kernel], noise_variance, tf.constant([1])) def __repr__(self) -> str: return "QuadraticMeanAndRBFKernel()" diff --git a/trieste/models/gpflow/interface.py b/trieste/models/gpflow/interface.py index 5fad89f679..83d63b8347 100644 --- a/trieste/models/gpflow/interface.py +++ b/trieste/models/gpflow/interface.py @@ -102,6 +102,10 @@ def update_posterior_cache(self) -> None: def model(self) -> GPModel: """The underlying GPflow model.""" + @property + def event_shape(self) -> TensorType: + return tf.constant([self.model.num_latent_gps]) + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, cov = (self._posterior or self.model).predict_f(query_points) diff --git a/trieste/models/gpflow/models.py b/trieste/models/gpflow/models.py index 235325e399..9b41b73da1 100644 --- a/trieste/models/gpflow/models.py +++ b/trieste/models/gpflow/models.py @@ -1417,6 +1417,10 @@ def __init__( def num_fidelities(self) -> int: return self._num_fidelities + @property + def event_shape(self) -> TensorType: + return self.lowest_fidelity_signal_model.event_shape + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """ diff --git a/trieste/models/gpflux/models.py b/trieste/models/gpflux/models.py index 903a01222c..4d1b763a7c 100644 --- a/trieste/models/gpflux/models.py +++ b/trieste/models/gpflux/models.py @@ -287,6 +287,10 @@ def model_gpflux(self) -> DeepGP: def model_keras(self) -> tf.keras.Model: return self._model_keras + @property + def event_shape(self) -> TensorType: + raise NotImplementedError # TODO: ??? + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: trajectory = self.trajectory_sampler().get_trajectory() diff --git a/trieste/models/interfaces.py b/trieste/models/interfaces.py index ae265d4f85..683b444a42 100644 --- a/trieste/models/interfaces.py +++ b/trieste/models/interfaces.py @@ -45,6 +45,13 @@ class ProbabilisticModel(Protocol): https://github.com/python/typing/issues/213 is implemented. """ + @property + @abstractmethod + def event_shape(self) -> TensorType: + """ + The event shape of the model's predictive distribution. + """ + @abstractmethod @check_shapes( "query_points: [batch..., D]", @@ -367,6 +374,11 @@ def __init__( # NB we don't use @inherit_shapes below as some classes break the shape API (👀 fantasizer) # instead we rely on the shape checking inside the submodels + @property + def event_shape(self) -> TensorType: + shapes = [model.event_shape for model in self._models] + return tf.concat([shapes[0][:-1], [sum(shape[-1] for shape in shapes)]], axis=0) + def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: r""" :param query_points: The points at which to make predictions, of shape [..., D]. diff --git a/trieste/models/keras/interface.py b/trieste/models/keras/interface.py index 96b222cccd..00560aff9e 100644 --- a/trieste/models/keras/interface.py +++ b/trieste/models/keras/interface.py @@ -57,6 +57,10 @@ def model(self) -> tf.keras.Model: """The compiled Keras model.""" raise NotImplementedError + @property + def event_shape(self) -> TensorType: + return tf.constant([self.model.output_shape[0][-1]]) + @property def optimizer(self) -> KerasOptimizer: """The optimizer wrapper for training the model."""