-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cast learning_rate to float lambda for pickle safety when doing model.load #1901
Cast learning_rate to float lambda for pickle safety when doing model.load #1901
Conversation
…ssed a function with non-float types
stable_baselines3/common/utils.py
Outdated
@@ -92,7 +92,7 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule: | |||
value_schedule = constant_fn(float(value_schedule)) | |||
else: | |||
assert callable(value_schedule) | |||
return value_schedule | |||
return lambda _: float(value_schedule(_)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe a better solution is to do a call to value_schedule(1.0)
and check that the return type is a float (and output a useful error message if not).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or, what you do is fine but I would explicitly name the parameter progress_remaining
and add a comment of why
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm... I see the value in both. Let me noodle a bit and I'll see if I can sort it out during my lunch. Thanks araffin!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks =)
Awesome! Thanks again araffin! The docs on SB3 you and the crew wrote and your guidance made this a breeze :) |
….load (DLR-RM#1901) * create failing test for unpickle error * Fix learning_rate argument causing failure in weights_only=True if passed a function with non-float types * Updated with feedback from araffin on PR#1901 * Update test and version * Update changelog and SBX doc --------- Co-authored-by: Antonin Raffin <[email protected]>
Description
closes #1900
Motivation and Context
closes [Bug]: if learning_rate function uses special types, they can cause torch.load to fail when weights_only=True #1900
Types of changes
Checklist
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line