Pass config to constructor when reviving custom functional model #20321

Open

wants to merge 2 commits into base: master
Conversation

@TrAyZeN TrAyZeN commented Oct 3, 2024

Currently, when loading a model instantiated from a custom Model subclass, its config is not passed to its constructor. This leads to some parameters not being restored.

Here is a snippet showing the behavior mentioned:

from keras import Model, layers

class CustomModel(Model):
    def __init__(self, *args, param=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.param = param

    def get_config(self):
        base_config = super().get_config()
        config = {"param": self.param}
        return base_config | config

inputs = layers.Input((3,))
outputs = layers.Dense(5)(inputs)
model = CustomModel(inputs=inputs, outputs=outputs, param=3)

new_model = CustomModel.from_config(model.get_config())
print(new_model.param)  # prints 1 currently, i.e. the default value of param

This PR proposes to fix this issue by having functional_from_config pass the remaining config entries to the model constructor as keyword arguments.
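
For illustration, here is a simplified, self-contained sketch of the idea (this is not the actual Keras source, which also has to deserialize the layers and rebuild the graph). The graph-specific keys are split off, and whatever is left in the config is what would be forwarded to the Model subclass constructor as keyword arguments:

def split_config(config):
    config = dict(config)  # avoid mutating the caller's dict
    # Keys that describe the functional graph itself.
    functional_config = {}
    for key in ["layers", "input_layers", "output_layers"]:
        functional_config[key] = config.pop(key)
    # Everything left over (e.g. "name" and custom arguments such as
    # "param") would be passed to cls(**config) by the revival code.
    return functional_config, config

functional_config, remaining = split_config(model.get_config())
print(remaining["param"])  # 3 -- the value the constructor should see again on reload

With such a split, CustomModel.from_config(model.get_config()) can construct the new instance with param=3 instead of falling back to the default.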


google-cla bot commented Oct 3, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.


codecov-commenter commented Oct 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 78.87%. Comparing base (ca88613) to head (0be55ca).
Report is 6 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20321      +/-   ##
==========================================
+ Coverage   78.81%   78.87%   +0.06%     
==========================================
  Files         512      513       +1     
  Lines       49056    49236     +180     
  Branches     9033     9076      +43     
==========================================
+ Hits        38664    38837     +173     
- Misses       8528     8532       +4     
- Partials     1864     1867       +3     
Flag              Coverage Δ
keras             78.73% <100.00%> (+0.06%) ⬆️
keras-jax         62.38% <100.00%> (+0.11%) ⬆️
keras-numpy       57.40% <100.00%> (-0.01%) ⬇️
keras-tensorflow  63.64% <100.00%> (+0.08%) ⬆️
keras-torch       62.37% <100.00%> (+0.11%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Member

@fchollet fchollet left a comment

Thanks for the PR -- looks good to me!

# remaining config will be passed as keyword arguments to the Model
# constructor.
functional_config = {}
for key in ["layers", "input_layers", "output_layers"]:
    functional_config[key] = config.pop(key)
Member


You can simplify this block:

  • single iteration over set of keys
  • use pop(key, None)
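
A hedged sketch of what that simplification could look like (assuming the block builds functional_config from those three keys):

functional_config = {
    key: config.pop(key, None)
    for key in ("layers", "input_layers", "output_layers")
}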

Author


I have applied your suggestion; I like the idea of simplifying it. However, it does not preserve the previous behavior: instead of raising a KeyError when layers, input_layers, or output_layers is missing, it now silently sets the value to None. It doesn't seem to be an issue, but I wanted to mention it.
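
To illustrate the difference, with a hypothetical malformed config that is missing "output_layers":

config = {"layers": [], "input_layers": []}  # hypothetical config missing "output_layers"

print(config.pop("output_layers", None))  # None -- the failure only surfaces later, less clearly
config.pop("output_layers")               # raises KeyError, pointing directly at the malformed config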

Member


You're right, it would lead to a strange error message instead of a clear one if the config is malformed. Let's avoid that.

Author


I've reverted the commit, then. I'm still waiting for the CLA to be handled; I'll let you know when that's done.
