
Model weights not saved if a custom layer class contains a list of layers named self._layers #20260

Open
mpetteno opened this issue Sep 16, 2024 · 2 comments

@mpetteno
Hi everyone,

I would like to report a problem I found while developing a custom model: a layer's weights are not saved if a `Model` subclass uses a custom layer that stores its sublayers in an instance attribute named `self._layers`. The code below reproduces the issue and should illustrate it better.

import os

# The backend must be selected before `keras` is imported, otherwise the
# environment variable has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

from pathlib import Path

import keras
import numpy as np



@keras.saving.register_keras_serializable(package="KerasTest", name="CustomLayer")
class CustomLayer(keras.Layer):

    def __init__(self, bugged: bool = False, name="test_layer", **kwargs):
        super(CustomLayer, self).__init__(name=name, **kwargs)
        self._bugged = bugged
        if self._bugged:
            self._layers = [
                keras.layers.Dense(64, activation='relu'),
                keras.layers.Dense(32, activation='relu'),
                keras.layers.Dense(10, activation='softmax')
            ]
        else:
            self._custom_layers = [
                keras.layers.Dense(64, activation='relu'),
                keras.layers.Dense(32, activation='relu'),
                keras.layers.Dense(10, activation='softmax')
            ]

    def call(self, inputs):
        x = inputs
        layer_list = self._layers if self._bugged else self._custom_layers
        for layer in layer_list:
            x = layer(x)
        return x

    def get_config(self):
        base_config = super().get_config()
        config = {
            "bugged": self._bugged
        }
        return {**base_config, **config}


@keras.saving.register_keras_serializable(package="KerasTest", name="TestModel")
class TestModel(keras.Model):

    def __init__(self, bugged: bool = False, name="test_model", **kwargs):
        super(TestModel, self).__init__(name=name, **kwargs)
        self._bugged = bugged
        self._custom_layer = CustomLayer(bugged=bugged)

    def call(self, inputs):
        return self._custom_layer(inputs)

    def get_config(self):
        base_config = super().get_config()
        config = {
            "bugged": self._bugged
        }
        return {**base_config, **config}


def test_model(bugged: bool = False):
    output_path = Path("./output")
    model_name_prefix = "bugged" if bugged else "fixed"

    # Dataset generation
    num_samples = 1000
    input_data = keras.random.uniform(shape=(num_samples, 32))
    labels = keras.random.randint(minval=0, maxval=10, shape=(num_samples,))
    labels = keras.utils.to_categorical(labels, num_classes=10)

    # Build, train, and capture the weights
    model = TestModel(bugged=bugged)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(input_data, labels, epochs=5, batch_size=32, validation_split=0.2)
    trained_weights = model.get_weights()

    # Save and load model
    output_path.mkdir(parents=True, exist_ok=True)
    model.save(output_path / f"{model_name_prefix}_model.keras")
    loaded_model = keras.saving.load_model(output_path / f"{model_name_prefix}_model.keras")
    loaded_weights = loaded_model.get_weights()

    print(f"------------ {model_name_prefix.capitalize()} - Compare trained weights to loaded weights -------------")
    for i, (tw, lw) in enumerate(zip(trained_weights, loaded_weights)):
        comparison = np.array_equal(tw, lw)
        print(f"Layer {i} -> Weights match: {comparison}")


if __name__ == '__main__':
    test_model(bugged=True)
    test_model(bugged=False)

The output is:

------------ Bugged - Compare trained weights to loaded weights -------------
Layer 0 -> Weights match: False
Layer 1 -> Weights match: False
Layer 2 -> Weights match: False
Layer 3 -> Weights match: False
Layer 4 -> Weights match: False
Layer 5 -> Weights match: False
------------ Fixed - Compare trained weights to loaded weights -------------
Layer 0 -> Weights match: True
Layer 1 -> Weights match: True
Layer 2 -> Weights match: True
Layer 3 -> Weights match: True
Layer 4 -> Weights match: True
Layer 5 -> Weights match: True

I believe the same problem was solved for the Model class by declaring a setter method (line 170). Perhaps the same approach could be applied to the Layer class.
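A minimal sketch of what such a guard could look like, in plain Python with hypothetical names (this is not the actual Keras implementation):

```python
class GuardedLayer:
    """Illustrative base class: it rejects assignments to the reserved
    internal attribute name `_layers`, so the name clash fails loudly at
    construction time instead of silently at save time."""

    @property
    def _layers(self):
        # Internal registry of tracked sublayers (illustrative only).
        return getattr(self, "_internal_layer_registry", [])

    @_layers.setter
    def _layers(self, value):
        raise AttributeError(
            "`_layers` is reserved for internal sublayer tracking; "
            "store your sublayers under a different attribute name."
        )


class MyLayer(GuardedLayer):
    def __init__(self):
        self._layers = []  # raises AttributeError immediately
```

With a guard like this, the repro above would fail at `CustomLayer.__init__` instead of producing mismatched weights after reloading.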

@fchollet (Member)

_layers is an internal property and you should not override it. Just use any other name.
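To make the silent failure concrete, here is a toy sketch in plain Python (hypothetical names; it mimics the failure mode only, not Keras internals) of how attribute-based tracking that reserves the name `_layers` can drop weights without raising any error:

```python
class ToyLayer:
    """Toy model of attribute-based sublayer tracking; illustrates the
    failure mode only, not how Keras is actually implemented."""

    def __init__(self, weight=None):
        object.__setattr__(self, "_tracked", [])  # internal registry
        self.weight = weight

    def __setattr__(self, name, value):
        # Sublayers assigned as attributes are auto-registered, except
        # under the reserved internal name `_layers`, which the tracker
        # skips, so those sublayers are never seen when saving.
        if name != "_layers" and isinstance(value, list):
            self._tracked.extend(v for v in value if isinstance(v, ToyLayer))
        object.__setattr__(self, name, value)

    def saved_weights(self):
        own = [] if self.weight is None else [self.weight]
        return own + [w for sub in self._tracked for w in sub.saved_weights()]


bugged = ToyLayer()
bugged._layers = [ToyLayer(1.0)]        # silently untracked: weights lost
fixed = ToyLayer()
fixed._custom_layers = [ToyLayer(1.0)]  # tracked and saved normally

print(bugged.saved_weights())  # []
print(fixed.saved_weights())   # [1.0]
```

Renaming the attribute, as suggested, is enough to avoid the clash.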

@mpetteno (Author) commented Sep 16, 2024

Yes, now I know, but since this is not documented anywhere and `_layers` is not an uncommon attribute name, I was suggesting applying the same approach used for the Model class: raise an error when the property is set from a subclass. The bug was not easy to find, because training proceeds without any errors or warnings; you only discover that something is wrong after saving and reloading the model.
