From a450749ebf0e1cfa1b79f8941c5d6b82be649cdc Mon Sep 17 00:00:00 2001 From: Neil Ferguson Date: Wed, 31 Jan 2024 13:21:41 +0000 Subject: [PATCH] Add compile_args to deep ensemble model (#816) --- tests/unit/models/gpflux/test_models.py | 23 +++++++++++++++++++ tests/unit/models/keras/test_models.py | 30 +++++++++++++++++++++++++ tests/util/models/keras/models.py | 7 ++++-- trieste/models/gpflux/models.py | 23 ++++++++++++++++--- trieste/models/keras/models.py | 23 +++++++++++++++---- 5 files changed, 97 insertions(+), 9 deletions(-) diff --git a/tests/unit/models/gpflux/test_models.py b/tests/unit/models/gpflux/test_models.py index e024a539f3..65abf611e5 100644 --- a/tests/unit/models/gpflux/test_models.py +++ b/tests/unit/models/gpflux/test_models.py @@ -553,3 +553,26 @@ def test_deepgp_log( assert mocked_summary_scalar.call_count == num_scalars assert mocked_summary_histogram.call_count == num_histogram + + +def test_deepgp_compile_args_specified() -> None: + x_observed = np.linspace(0, 10, 10).reshape((-1, 1)) + model = single_layer_dgp_model(x_observed) + # If we get this error we know that the compile_args are being passed to the model + # because Keras will throw an error if it receives both of these arguments. + with pytest.raises( + ValueError, match="You cannot enable `run_eagerly` and `jit_compile` at the same time." + ): + DeepGaussianProcess(model, compile_args={"jit_compile": True, "run_eagerly": True}) + + +def test_deepgp_disallowed_compile_args_specified() -> None: + mock_model = unittest.mock.MagicMock(spec=DeepGP) + with pytest.raises(ValueError): + DeepGaussianProcess( + mock_model, compile_args={"run_eagerly": True, "optimizer": unittest.mock.MagicMock()} + ) + with pytest.raises(ValueError): + DeepGaussianProcess( + mock_model, compile_args={"run_eagerly": True, "metrics": unittest.mock.MagicMock()} + ) diff --git a/tests/unit/models/keras/test_models.py b/tests/unit/models/keras/test_models.py index e25f73400a..8a6eaf1413 100644 --- a/tests/unit/models/keras/test_models.py +++ b/tests/unit/models/keras/test_models.py @@ -198,6 +198,36 @@ def test_deep_ensemble_is_compiled() -> None: assert model.model.optimizer is not None +def test_deep_ensemble_compile_args_specified() -> None: + example_data = empty_dataset([1], [1]) + # If we get this error we know that the compile_args are being passed to the model + # because Keras will throw an error if it receives both of these arguments. + with pytest.raises( + ValueError, match="You cannot enable `run_eagerly` and `jit_compile` at the same time." 
+    ):
+        model, _, _ = trieste_deep_ensemble_model(
+            example_data, _ENSEMBLE_SIZE, compile_args={"run_eagerly": True, "jit_compile": True}
+        )
+
+
+def test_deep_ensemble_disallowed_compile_args_specified() -> None:
+    mock_ensemble = unittest.mock.MagicMock(spec=KerasEnsemble)
+    mock_ensemble.ensemble_size = _ENSEMBLE_SIZE
+    with pytest.raises(ValueError):
+        DeepEnsemble(
+            mock_ensemble,
+            compile_args={"run_eagerly": True, "optimizer": unittest.mock.MagicMock()},
+        )
+    with pytest.raises(ValueError):
+        DeepEnsemble(
+            mock_ensemble, compile_args={"run_eagerly": True, "loss": unittest.mock.MagicMock()}
+        )
+    with pytest.raises(ValueError):
+        DeepEnsemble(
+            mock_ensemble, compile_args={"run_eagerly": True, "metrics": unittest.mock.MagicMock()}
+        )
+
+
 def test_deep_ensemble_resets_lr_with_lr_schedule() -> None:
     example_data = _get_example_data([100, 1])
 
diff --git a/tests/util/models/keras/models.py b/tests/util/models/keras/models.py
index 1b7ca11c3c..8e40a9f71b 100644
--- a/tests/util/models/keras/models.py
+++ b/tests/util/models/keras/models.py
@@ -18,7 +18,7 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Any, Mapping, Optional, Tuple
 
 import tensorflow as tf
 from packaging.version import Version
@@ -63,6 +63,7 @@ def trieste_deep_ensemble_model(
     ensemble_size: int,
     bootstrap_data: bool = False,
     independent_normal: bool = False,
+    compile_args: Optional[Mapping[str, Any]] = None,
 ) -> Tuple[DeepEnsemble, KerasEnsemble, KerasOptimizer]:
     keras_ensemble = trieste_keras_ensemble_model(example_data, ensemble_size, independent_normal)
 
@@ -75,7 +76,9 @@ def trieste_deep_ensemble_model(
     }
     optimizer_wrapper = KerasOptimizer(optimizer, fit_args)
 
-    model = DeepEnsemble(keras_ensemble, optimizer_wrapper, bootstrap_data)
+    model = DeepEnsemble(
+        keras_ensemble, optimizer_wrapper, bootstrap_data, compile_args=compile_args
+    )
 
     return model, keras_ensemble, optimizer_wrapper
 
diff --git a/trieste/models/gpflux/models.py b/trieste/models/gpflux/models.py
index 903a01222c..79a1777df5 100644
--- a/trieste/models/gpflux/models.py
+++ b/trieste/models/gpflux/models.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Mapping, Optional
 
 import dill
 import gpflow
@@ -64,6 +64,7 @@
         optimizer: KerasOptimizer | None = None,
         num_rff_features: int = 1000,
         continuous_optimisation: bool = True,
+        compile_args: Optional[Mapping[str, Any]] = None,
     ):
         """
         :param model: The underlying GPflux deep Gaussian process model. Passing in a named closure
@@ -82,9 +83,23 @@
         :param continuous_optimisation: if True (default), the optimizer will keep track of the
             number of epochs across BO iterations and use this number as initial_epoch. This is
             essential to allow monitoring of model training across BO iterations.
+        :param compile_args: Keyword arguments to pass to the ``compile`` method of the
+            Keras model (:class:`~tf.keras.Model`).
+            See https://keras.io/api/models/model_training_apis/#compile-method for a
+            list of possible arguments. The ``optimizer`` and ``metrics`` arguments
+            must not be included.
         :raise ValueError: If ``model`` has unsupported layers, ``num_rff_features`` is less than
-            0, or if the ``optimizer`` is not of a supported type.
+            0, if the ``optimizer`` is not of a supported type, or if ``compile_args`` contains
+            disallowed arguments.
""" + if compile_args is None: + compile_args = {} + + if not {"optimizer", "metrics"}.isdisjoint(compile_args): + raise ValueError( + "optimizer and metrics arguments must not be included in compile_args." + ) + if isinstance(model, DeepGP): self._model_closure = None else: @@ -152,7 +167,9 @@ def scheduler(epoch: int, lr: float) -> float: dtype=tf.float64, ) self._model_keras = model.as_training_model() - self._model_keras.compile(self.optimizer.optimizer, metrics=self.optimizer.metrics) + self._model_keras.compile( + optimizer=self.optimizer.optimizer, metrics=self.optimizer.metrics, **compile_args + ) self._absolute_epochs = 0 self._continuous_optimisation = continuous_optimisation diff --git a/trieste/models/keras/models.py b/trieste/models/keras/models.py index e434185988..54fdfe44a5 100644 --- a/trieste/models/keras/models.py +++ b/trieste/models/keras/models.py @@ -15,7 +15,7 @@ from __future__ import annotations import re -from typing import Any, Dict, Optional +from typing import Any, Dict, Mapping, Optional import dill import keras.callbacks @@ -85,6 +85,7 @@ def __init__( bootstrap: bool = False, diversify: bool = False, continuous_optimisation: bool = True, + compile_args: Optional[Mapping[str, Any]] = None, ) -> None: """ :param model: A Keras ensemble model with probabilistic networks as ensemble members. The @@ -106,15 +107,28 @@ def __init__( :param continuous_optimisation: If True (default), the optimizer will keep track of the number of epochs across BO iterations and use this number as initial_epoch. This is essential to allow monitoring of model training across BO iterations. + :param compile_args: Keyword arguments to pass to the ``compile`` method of the + Keras model (:class:`~tf.keras.Model`). + See https://keras.io/api/models/model_training_apis/#compile-method for a + list of possible arguments. The ``optimizer``, ``loss`` and ``metrics`` arguments + must not be included. :raise ValueError: If ``model`` is not an instance of - :class:`~trieste.models.keras.KerasEnsemble` or ensemble has less than two base - learners (networks). + :class:`~trieste.models.keras.KerasEnsemble`, or ensemble has less than two base + learners (networks), or `compile_args` contains disallowed arguments. """ if model.ensemble_size < 2: raise ValueError(f"Ensemble size must be greater than 1 but got {model.ensemble_size}.") super().__init__(optimizer) + if compile_args is None: + compile_args = {} + + if not {"optimizer", "loss", "metrics"}.isdisjoint(compile_args): + raise ValueError( + "optimizer, loss and metrics arguments must not be included in compile_args." + ) + if not self.optimizer.fit_args: self.optimizer.fit_args = { "verbose": 0, @@ -134,9 +148,10 @@ def __init__( self.optimizer.metrics = ["mse"] model.model.compile( - self.optimizer.optimizer, + optimizer=self.optimizer.optimizer, loss=[self.optimizer.loss] * model.ensemble_size, metrics=[self.optimizer.metrics] * model.ensemble_size, + **compile_args, ) if not isinstance(