diff --git a/trieste/models/gpflux/models.py b/trieste/models/gpflux/models.py index 8f685f580d..b311d41322 100644 --- a/trieste/models/gpflux/models.py +++ b/trieste/models/gpflux/models.py @@ -212,6 +212,10 @@ def __getstate__(self) -> dict[str, Any]: finally: self._model_keras.history.model = history_model + # don't try to serialize any other copies of the history callback + if isinstance(state.get("_last_optimization_result"), keras.callbacks.History): + state["_last_optimization_result"] = ... + return state def __setstate__(self, state: dict[str, Any]) -> None: @@ -265,6 +269,8 @@ def __setstate__(self, state: dict[str, Any]) -> None: model = tf.keras.models.model_from_json(model_json) model.set_weights(weights) self._model_keras.history.set_model(model) + if state.get("_last_optimization_result") is ...: + self._last_optimization_result = self._model_keras.history def __repr__(self) -> str: """""" diff --git a/trieste/models/keras/architectures.py b/trieste/models/keras/architectures.py index 06dad5931f..7fc19385d0 100644 --- a/trieste/models/keras/architectures.py +++ b/trieste/models/keras/architectures.py @@ -22,6 +22,7 @@ from typing import Any, Callable, Sequence import dill +import keras.callbacks import numpy as np import tensorflow as tf import tensorflow_probability as tfp @@ -127,6 +128,10 @@ def __getstate__(self) -> dict[str, Any]: finally: self._model.history.model = history_model + # Don't try to serialize any other copies of the history callback + if isinstance(state.get("_last_optimization_result"), keras.callbacks.History): + state["_last_optimization_result"] = ... + return state def __setstate__(self, state: dict[str, Any]) -> None: @@ -150,6 +155,8 @@ def __setstate__(self, state: dict[str, Any]) -> None: ) model.set_weights(weights) self._model.history.set_model(model) + if state.get("_last_optimization_result") is ...: + self._last_optimization_result = self._model.history class KerasEnsembleNetwork: