Handling checkpoint-breaking changes #48
Comments
These are some very important considerations. I myself have angered some colleagues by making old checkpoints unusable. Now I am also looking at #49, which would give users much more flexibility with respect to model choices. Mostly for that reason, and because I don't think we have the manpower to ensure backwards compatibility, I am leaning towards option 1. Maybe in the future, with a more stable repo and more staff, we can implement 3?
With such information, every checkpoint should be usable for a long time. Maybe I am greatly overestimating how much time 3 would require; if that is the case, I will gladly change my opinion.
I am a bit unsure myself about how much work it would really be. As long as we only rename members or change the module hierarchy, something like the existing code in neural-lam/neural_lam/models/ar_model.py (lines 584 to 596 in 9d558d1) should work.
It just has to be generalized beyond the single parameter g2m_gnn.grid_mlp.0.weight.
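A sketch of how that rename could be generalized: instead of special-casing one parameter, map whole old key prefixes to new ones, so every weight and bias below a renamed submodule is moved. `RENAMED_PREFIXES`, the new name `grid_encoder_mlp.` and `rename_state_dict_keys` are hypothetical and not part of neural-lam.

```python
from collections import OrderedDict

# Hypothetical mapping from old parameter-name prefixes to new ones.
# Matching on a prefix moves every weight and bias below a renamed
# submodule, not just a single parameter.
RENAMED_PREFIXES = {
    "g2m_gnn.grid_mlp.": "grid_encoder_mlp.",  # new name is made up
}


def rename_state_dict_keys(state_dict, renamed_prefixes):
    """Return a copy of state_dict with old key prefixes replaced."""
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in renamed_prefixes.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix) :]
                break
        # Guard against a rename colliding with an existing key
        if new_key in new_state_dict:
            raise KeyError(f"Renaming produced duplicate key {new_key!r}")
        new_state_dict[new_key] = value
    return new_state_dict
```

Since it only touches keys, the same function could be applied to `checkpoint["state_dict"]` either in a one-off conversion script or inside a loading hook.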
Things can get tricky if we reorder input features or change the dimensionality of something. But thinking about this a bit more now, I realize:
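For a changed input dimensionality, one possible approach (an assumption, not something proposed in this thread) is to add all-zero columns to the first linear layer's weight for newly added input features, so the upgraded model initially computes the same outputs as before. `insert_zero_input_columns` is a hypothetical helper and assumes an `nn.Linear`-style weight of shape `(out_features, in_features)`.

```python
import torch


def insert_zero_input_columns(weight, position, num_new):
    """
    Insert num_new all-zero columns into a weight of shape
    (out_features, in_features), starting at column index `position`.
    Inputs fed to the new columns then have no effect on the output.
    """
    zeros = weight.new_zeros((weight.shape[0], num_new))
    return torch.cat((weight[:, :position], zeros, weight[:, position:]), dim=1)
```

For a feature appended at the end, `position` would be the old `in_features`. This only preserves behaviour if the new features enter the network exclusively through that layer.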
I had to do some "surgery" on one of my old checkpoint files after I changed the ordering of input features in the implementation. This corresponds to the first bullet point in my comment above. Here is the script, as an example of what a checkpoint-conversion script could look like:

```python
# Standard library
import os
from argparse import ArgumentParser
from collections import OrderedDict

# Third-party
import torch

# Parameters to reorder input dimensions in, mapping from_dim -> to_dim
# NOTE: If there are multiple reorders per parameter, they are applied
# sequentially, each one on the result of the previous one
REORDER_INPUT_DIMS = {
    "grid_prev_embedder.0.weight": OrderedDict({49: 34}),
    "grid_current_embedder.0.weight": OrderedDict({66: 51}),
}


def main():
    """
    Upgrade a checkpoint file to reflect changes to architecture.
    Here specifically reordering of input features.
    """
    parser = ArgumentParser(description="Upgrade checkpoint file")
    parser.add_argument(
        "--load",
        type=str,
        required=True,
        help="Path to checkpoint file to upgrade",
    )
    args = parser.parse_args()

    # Load checkpoint file
    checkpoint_dict = torch.load(args.load, map_location="cpu")
    state_dict = checkpoint_dict["state_dict"]

    # Reorder dimensions
    for param_name, reorder_dict in REORDER_INPUT_DIMS.items():
        # Reorder dimensions in this param
        param_tensor = state_dict[param_name]
        for from_dim, to_dim in reorder_dict.items():
            # Extract vector at from_dim,
            # indexing along dim 1 for input features
            moved_vec = param_tensor[:, from_dim : (from_dim + 1)]
            # Remove from_dim from param
            param_tensor = torch.cat(
                (param_tensor[:, :from_dim], param_tensor[:, (from_dim + 1) :]),
                dim=1,
            )
            # Insert vector as dimension to_dim
            param_tensor = torch.cat(
                (param_tensor[:, :to_dim], moved_vec, param_tensor[:, to_dim:]),
                dim=1,
            )

        # Re-write parameter in state dict
        state_dict[param_name] = param_tensor

    # Save updated checkpoint next to the original, with an "upgraded_" prefix
    path_dirname, path_basename = os.path.split(args.load)
    upgraded_ckpt_path = os.path.join(path_dirname, f"upgraded_{path_basename}")
    torch.save(checkpoint_dict, upgraded_ckpt_path)


if __name__ == "__main__":
    main()
```
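The sequential `torch.cat` surgery in the conversion script can also be expressed as a single permutation of the input-feature columns. A minimal sketch in plain Python, assuming the moves are applied in order exactly as in the script; `moves_to_permutation` is a hypothetical helper:

```python
def moves_to_permutation(num_features, moves):
    """
    Turn a sequence of (from_dim, to_dim) moves into one index list.

    Each move mirrors one step of the script: take the column at from_dim
    out, then insert it at position to_dim of the shortened list.
    Entry j of the result is the *old* index of the feature that ends up
    in position j.
    """
    order = list(range(num_features))
    for from_dim, to_dim in moves:
        feature = order.pop(from_dim)
        order.insert(to_dim, feature)
    return order
```

With torch, reordering one parameter then becomes a single indexing step, e.g. `state_dict[name] = state_dict[name][:, perm]`, which should give the same tensor as the repeated `torch.cat` calls.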
Background
As we make more changes to the code, there will be points where checkpoints of saved models cannot be loaded directly in a newer version of neural-lam. This happens in particular when we rename nn.Module attributes or change the overall structure of the model classes. It would be good to have a policy for how we want to handle such breaking changes. This issue is for discussing this.
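A minimal illustration of why renaming an nn.Module attribute breaks loading: state dict keys are built from attribute names, so the keys stored in an old checkpoint no longer match. The toy classes `OldModel` and `NewModel` are hypothetical, not neural-lam code.

```python
from torch import nn


class OldModel(nn.Module):
    """Hypothetical model as it looked when the checkpoint was saved."""

    def __init__(self):
        super().__init__()
        self.grid_mlp = nn.Linear(4, 8)


class NewModel(nn.Module):
    """Same architecture, but the attribute has been renamed."""

    def __init__(self):
        super().__init__()
        self.grid_encoder = nn.Linear(4, 8)


old_state_dict = OldModel().state_dict()  # keys: grid_mlp.weight, grid_mlp.bias
new_model = NewModel()

# With the default strict=True, loading raises a RuntimeError
try:
    new_model.load_state_dict(old_state_dict)
    strict_load_failed = False
except RuntimeError:
    strict_load_failed = True

# With strict=False nothing is loaded into the renamed layer; the returned
# object lists the keys that did not match instead of raising
incompatible = new_model.load_state_dict(old_state_dict, strict=False)
missing = sorted(incompatible.missing_keys)
unexpected = sorted(incompatible.unexpected_keys)
```

Renaming the loaded keys from `grid_mlp.*` to `grid_encoder.*` before calling `load_state_dict` would make the strict load succeed, which is the kind of conversion this issue is about.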
Proposals
I see three main options:
ARModel: neural-lam/neural_lam/models/ar_model.py, lines 576 to 596 in 9d558d1
Considerations for points 2 and 3
My view
on_load_checkpoint would get unnecessarily complicated, and I'd rather just do the conversion once and have a set of new checkpoint files. It is also easy to do both 2 and 3: if you try to load an old checkpoint, you just convert it before loading.
Tagging @leifdenby and @sadamov to get your input.
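One way to make a conversion reusable both as a one-off upgrade and at load time is to store a format version in each checkpoint and chain small upgrade steps. A minimal sketch, assuming a hypothetical `checkpoint_format_version` key; the upgrade steps, the new name `grid_encoder_mlp` and `obsolete_param` are made up for illustration.

```python
# Hypothetical key under which a format version is stored in each checkpoint
VERSION_KEY = "checkpoint_format_version"
CURRENT_VERSION = 2

OLD_PREFIX = "g2m_gnn.grid_mlp."
NEW_PREFIX = "grid_encoder_mlp."  # hypothetical new name


def _upgrade_v0_to_v1(checkpoint):
    """Hypothetical upgrade step: a submodule was renamed."""
    state_dict = checkpoint["state_dict"]
    for key in list(state_dict):  # copy keys, since the dict is modified
        if key.startswith(OLD_PREFIX):
            new_key = NEW_PREFIX + key[len(OLD_PREFIX) :]
            state_dict[new_key] = state_dict.pop(key)


def _upgrade_v1_to_v2(checkpoint):
    """Hypothetical upgrade step: a parameter was removed from the model."""
    checkpoint["state_dict"].pop("obsolete_param", None)


# UPGRADES[v] converts a version-v checkpoint into a version-(v + 1) one
UPGRADES = {0: _upgrade_v0_to_v1, 1: _upgrade_v1_to_v2}


def upgrade_checkpoint(checkpoint):
    """Apply all missing upgrade steps in order, modifying checkpoint in place."""
    # Checkpoints saved before versioning was introduced count as version 0
    version = checkpoint.get(VERSION_KEY, 0)
    if version > CURRENT_VERSION:
        raise ValueError(
            f"Checkpoint version {version} is newer than this code "
            f"({CURRENT_VERSION})"
        )
    while version < CURRENT_VERSION:
        UPGRADES[version](checkpoint)
        version += 1
    checkpoint[VERSION_KEY] = version
    return checkpoint
```

The same `upgrade_checkpoint` could be called from a standalone script that writes new checkpoint files, or from PyTorch Lightning's `on_load_checkpoint(checkpoint)` hook, which receives the loaded checkpoint dictionary before the state dict is restored. Writing `CURRENT_VERSION` into new checkpoints could be done in the matching `on_save_checkpoint` hook.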