Skip to content

Commit

Permalink
✨ Add script for uploading ckpts
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Dec 22, 2023
1 parent 223407a commit a1d2530
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torch_uncertainty/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from huggingface_hub import hf_hub_download


def load_hf(weight_id: str) -> tuple[torch.Tensor, dict]:
def load_hf(weight_id: str, version: int = 0) -> tuple[torch.Tensor, dict]:
"""Load a model from the HuggingFace hub.
Args:
weight_id (str): The id of the model to load.
version (int): The id of the version when there are several on HF.
Returns:
Tuple[Tensor, Dict]: The model weights and config.
Expand All @@ -20,7 +21,12 @@ def load_hf(weight_id: str) -> tuple[torch.Tensor, dict]:
repo_id = f"torch-uncertainty/{weight_id}"

# Load the weights
weight_path = hf_hub_download(repo_id=repo_id, filename=f"{weight_id}.ckpt")
if version is None or version == 0:
filename = f"{weight_id}.ckpt"
else:
filename = f"{weight_id}_{version}.ckpt"

weight_path = hf_hub_download(repo_id=repo_id, filename=filename)
weight = torch.load(weight_path, map_location=torch.device("cpu"))
if "state_dict" in weight: # coverage: ignore
weight = weight["state_dict"]
Expand Down
35 changes: 35 additions & 0 deletions torch_uncertainty/utils/to_hub_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Creates the checkpoint for hugging face."""

import argparse
from pathlib import Path

import torch

parser = argparse.ArgumentParser(
prog="to_hub_format",
description="Post-process the checkpoints before the upload to HuggingFace",
)
parser.add_argument(
"--name", type=str, required=True, help="path to the checkpoint"
)
parser.add_argument(
"--path", type=Path, required=True, help="path to the checkpoint"
)
parser.add_argument(
"--version", type=int, default=0, help="path to the checkpoint"
)

args = parser.parse_args()

if not args.path.exists():
raise ValueError("File does not exist")

model = torch.load(args.path)["state_dict"]
model = {key.replace("model.", ""): val.cpu() for key, val in model.items()}

output_name = args.name
if args.version != 0:
output_name += "_" + str(args.version)
output_name += ".ckpt"

torch.save(model, output_name)

0 comments on commit a1d2530

Please sign in to comment.