Skip to content

Commit

Permalink
Add compile_args to deep ensemble model (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfergu authored Jan 31, 2024
1 parent 0aefc2b commit a450749
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 9 deletions.
23 changes: 23 additions & 0 deletions tests/unit/models/gpflux/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,26 @@ def test_deepgp_log(

assert mocked_summary_scalar.call_count == num_scalars
assert mocked_summary_histogram.call_count == num_histogram


def test_deepgp_compile_args_specified() -> None:
x_observed = np.linspace(0, 10, 10).reshape((-1, 1))
model = single_layer_dgp_model(x_observed)
# If we get this error we know that the compile_args are being passed to the model
# because Keras will throw an error if it receives both of these arguments.
with pytest.raises(
ValueError, match="You cannot enable `run_eagerly` and `jit_compile` at the same time."
):
DeepGaussianProcess(model, compile_args={"jit_compile": True, "run_eagerly": True})


def test_deepgp_disallowed_compile_args_specified() -> None:
mock_model = unittest.mock.MagicMock(spec=DeepGP)
with pytest.raises(ValueError):
DeepGaussianProcess(
mock_model, compile_args={"run_eagerly": True, "optimizer": unittest.mock.MagicMock()}
)
with pytest.raises(ValueError):
DeepGaussianProcess(
mock_model, compile_args={"run_eagerly": True, "metrics": unittest.mock.MagicMock()}
)
30 changes: 30 additions & 0 deletions tests/unit/models/keras/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,36 @@ def test_deep_ensemble_is_compiled() -> None:
assert model.model.optimizer is not None


def test_deep_ensemble_compile_args_specified() -> None:
example_data = empty_dataset([1], [1])
# If we get this error we know that the compile_args are being passed to the model
# because Keras will throw an error if it receives both of these arguments.
with pytest.raises(
ValueError, match="You cannot enable `run_eagerly` and `jit_compile` at the same time."
):
model, _, _ = trieste_deep_ensemble_model(
example_data, _ENSEMBLE_SIZE, compile_args={"run_eagerly": True, "jit_compile": True}
)


def test_deep_ensemble_disallowed_compile_args_specified() -> None:
mock_ensemble = unittest.mock.MagicMock(spec=KerasEnsemble)
mock_ensemble.ensemble_size = _ENSEMBLE_SIZE
with pytest.raises(ValueError):
DeepEnsemble(
mock_ensemble,
compile_args={"run_eagerly": True, "optimizer": unittest.mock.MagicMock()},
)
with pytest.raises(ValueError):
DeepEnsemble(
mock_ensemble, compile_args={"run_eagerly": True, "loss": unittest.mock.MagicMock()}
)
with pytest.raises(ValueError):
DeepEnsemble(
mock_ensemble, compile_args={"run_eagerly": True, "metrics": unittest.mock.MagicMock()}
)


def test_deep_ensemble_resets_lr_with_lr_schedule() -> None:
example_data = _get_example_data([100, 1])

Expand Down
7 changes: 5 additions & 2 deletions tests/util/models/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from __future__ import annotations

from typing import Optional, Tuple
from typing import Any, Mapping, Optional, Tuple

import tensorflow as tf
from packaging.version import Version
Expand Down Expand Up @@ -63,6 +63,7 @@ def trieste_deep_ensemble_model(
ensemble_size: int,
bootstrap_data: bool = False,
independent_normal: bool = False,
compile_args: Optional[Mapping[str, Any]] = None,
) -> Tuple[DeepEnsemble, KerasEnsemble, KerasOptimizer]:
keras_ensemble = trieste_keras_ensemble_model(example_data, ensemble_size, independent_normal)

Expand All @@ -75,7 +76,9 @@ def trieste_deep_ensemble_model(
}
optimizer_wrapper = KerasOptimizer(optimizer, fit_args)

model = DeepEnsemble(keras_ensemble, optimizer_wrapper, bootstrap_data)
model = DeepEnsemble(
keras_ensemble, optimizer_wrapper, bootstrap_data, compile_args=compile_args
)

return model, keras_ensemble, optimizer_wrapper

Expand Down
23 changes: 20 additions & 3 deletions trieste/models/gpflux/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Any, Callable, Optional
from typing import Any, Callable, Mapping, Optional

import dill
import gpflow
Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(
optimizer: KerasOptimizer | None = None,
num_rff_features: int = 1000,
continuous_optimisation: bool = True,
compile_args: Optional[Mapping[str, Any]] = None,
):
"""
:param model: The underlying GPflux deep Gaussian process model. Passing in a named closure
Expand All @@ -82,9 +83,23 @@ def __init__(
:param continuous_optimisation: if True (default), the optimizer will keep track of the
number of epochs across BO iterations and use this number as initial_epoch. This is
essential to allow monitoring of model training across BO iterations.
:param compile_args: Keyword arguments to pass to the ``compile`` method of the
Keras model (:class:`~tf.keras.Model`).
See https://keras.io/api/models/model_training_apis/#compile-method for a
list of possible arguments. The ``optimizer`` and ``metrics`` arguments
must not be included.
:raise ValueError: If ``model`` has unsupported layers, ``num_rff_features`` is less than 0,
or if the ``optimizer`` is not of a supported type.
if the ``optimizer`` is not of a supported type, or `compile_args` contains
disallowed arguments.
"""
if compile_args is None:
compile_args = {}

if not {"optimizer", "metrics"}.isdisjoint(compile_args):
raise ValueError(
"optimizer and metrics arguments must not be included in compile_args."
)

if isinstance(model, DeepGP):
self._model_closure = None
else:
Expand Down Expand Up @@ -152,7 +167,9 @@ def scheduler(epoch: int, lr: float) -> float:
dtype=tf.float64,
)
self._model_keras = model.as_training_model()
self._model_keras.compile(self.optimizer.optimizer, metrics=self.optimizer.metrics)
self._model_keras.compile(
optimizer=self.optimizer.optimizer, metrics=self.optimizer.metrics, **compile_args
)
self._absolute_epochs = 0
self._continuous_optimisation = continuous_optimisation

Expand Down
23 changes: 19 additions & 4 deletions trieste/models/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import re
from typing import Any, Dict, Optional
from typing import Any, Dict, Mapping, Optional

import dill
import keras.callbacks
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
bootstrap: bool = False,
diversify: bool = False,
continuous_optimisation: bool = True,
compile_args: Optional[Mapping[str, Any]] = None,
) -> None:
"""
:param model: A Keras ensemble model with probabilistic networks as ensemble members. The
Expand All @@ -106,15 +107,28 @@ def __init__(
:param continuous_optimisation: If True (default), the optimizer will keep track of the
number of epochs across BO iterations and use this number as initial_epoch. This is
essential to allow monitoring of model training across BO iterations.
:param compile_args: Keyword arguments to pass to the ``compile`` method of the
Keras model (:class:`~tf.keras.Model`).
See https://keras.io/api/models/model_training_apis/#compile-method for a
list of possible arguments. The ``optimizer``, ``loss`` and ``metrics`` arguments
must not be included.
:raise ValueError: If ``model`` is not an instance of
:class:`~trieste.models.keras.KerasEnsemble` or ensemble has less than two base
learners (networks).
:class:`~trieste.models.keras.KerasEnsemble`, or ensemble has less than two base
learners (networks), or `compile_args` contains disallowed arguments.
"""
if model.ensemble_size < 2:
raise ValueError(f"Ensemble size must be greater than 1 but got {model.ensemble_size}.")

super().__init__(optimizer)

if compile_args is None:
compile_args = {}

if not {"optimizer", "loss", "metrics"}.isdisjoint(compile_args):
raise ValueError(
"optimizer, loss and metrics arguments must not be included in compile_args."
)

if not self.optimizer.fit_args:
self.optimizer.fit_args = {
"verbose": 0,
Expand All @@ -134,9 +148,10 @@ def __init__(
self.optimizer.metrics = ["mse"]

model.model.compile(
self.optimizer.optimizer,
optimizer=self.optimizer.optimizer,
loss=[self.optimizer.loss] * model.ensemble_size,
metrics=[self.optimizer.metrics] * model.ensemble_size,
**compile_args,
)

if not isinstance(
Expand Down

0 comments on commit a450749

Please sign in to comment.