diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index e01d22d52728..91ff178459e8 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -149,10 +149,12 @@ def serialize_keras_object(obj): # Special cases: if isinstance(obj, bytes): - return { - "class_name": "__bytes__", - "config": {"value": obj.decode("utf-8")}, - } + try: + value = obj.decode("utf-8") + # For `torch` backend `latin-1` works + except UnicodeDecodeError: + value = obj.decode("latin-1") + return {"class_name": "__bytes__", "config": {"value": value}} if isinstance(obj, slice): return { "class_name": "__slice__", diff --git a/keras/src/utils/torch_utils.py b/keras/src/utils/torch_utils.py index a5e1e1e2fb0e..0c340f83ccb0 100644 --- a/keras/src/utils/torch_utils.py +++ b/keras/src/utils/torch_utils.py @@ -96,7 +96,10 @@ def __init__(self, module, name=None, **kwargs): f"Received uninitialized LazyModule: module={module}" ) - self.module = module.to(get_device()) + if hasattr(module, "to"): + self.module = module.to(get_device()) + elif isinstance(module, dict) and module["class_name"] == "__bytes__": + self.module = module["config"]["value"] self._track_module_parameters() def parameters(self, recurse=True): @@ -152,8 +155,12 @@ def from_config(cls, config): import torch if "module" in config: - buffer = io.BytesIO(config["module"]) - config["module"] = torch.load(buffer, weights_only=False) + buffer = io.BytesIO( + config["module"]["config"]["value"].encode("latin-1") + ) + config["module"]["config"]["value"] = torch.load( + buffer, weights_only=False + ) return cls(**config) diff --git a/keras/src/utils/torch_utils_test.py b/keras/src/utils/torch_utils_test.py index 1be561d94f5e..20dcb8c6abd4 100644 --- a/keras/src/utils/torch_utils_test.py +++ b/keras/src/utils/torch_utils_test.py @@ -10,6 +10,8 @@ from keras.src import models from keras.src import saving from keras.src import testing +from keras.src.models 
import Model +from keras.src.saving import load_model from keras.src.utils.torch_utils import TorchModuleWrapper
@@ -235,3 +237,35 @@ def test_from_config(self):
         new_mw = TorchModuleWrapper.from_config(config)
         for ref_w, new_w in zip(mw.get_weights(), new_mw.get_weights()):
             self.assertAllClose(ref_w, new_w, atol=1e-5)
+
+    def test_serialize_deserialize_TorchModuleWrapper(self):
+        """Check that a model wrapping a torch module survives save/load.
+
+        Builds a functional model around a `TorchModuleWrapper`, saves it to
+        a `.keras` file, reloads it, and compares the predictions of both
+        models on the same inputs.
+        """
+        # Function-scope imports so this test does not rely on file-level ones.
+        import os
+        import tempfile
+
+        torch_module = torch.nn.Linear(4, 4)
+        wrapped_layer = TorchModuleWrapper(torch_module)
+
+        inputs = layers.Input(shape=(4,))
+        outputs = wrapped_layer(inputs)
+        model = Model(inputs=inputs, outputs=outputs)
+
+        test_ds = np.random.random(size=(10, 4))
+        result_1 = model.predict(test_ds)
+
+        # Save into a throwaway directory instead of the current working
+        # directory, so the test leaves no `serialized.keras` behind and
+        # concurrent test runs cannot overwrite each other's file.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            filepath = os.path.join(tmp_dir, "serialized.keras")
+            model.save(filepath)
+            reloaded_model = load_model(filepath)
+
+        result_2 = reloaded_model.predict(test_ds)
+        self.assertAllClose(result_1, result_2)