Skip to content

Commit

Permalink
Fix DE serialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Sep 4, 2024
1 parent a37ffc7 commit 9045c4c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions trieste/models/keras/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from gpflow.keras import tf_keras

try:
from keras.src.saving.serialization_lib import SafeModeScope
except ImportError: # pragma: no cover (tested but not by coverage)
SafeModeScope = tf_keras.src.saving.serialization_lib.SafeModeScope
except AttributeError: # pragma: no cover (tested but not by coverage)
SafeModeScope = contextlib.nullcontext
from tensorflow_probability.python.layers.distribution_layer import DistributionLambda, _serialize

Expand Down Expand Up @@ -147,7 +147,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
# When unpickling restore the model using model_from_json.
self.__dict__.update(state)
# TF 2.15 disallows loading lambdas without "safe-mode" being disabled
# unfortunately, tfp.layers.DistributionLambda seems to use lambdas
# unfortunately, tfp.layers.DistributionLambda uses lambdas
with SafeModeScope(False):
self._model = tf_keras.models.model_from_json(
state["_model"], custom_objects={"MultivariateNormalTriL": MultivariateNormalTriL}
Expand Down

0 comments on commit 9045c4c

Please sign in to comment.