From 4c9e7be8404b63dcb6323c97fd917d49b3ac0bf2 Mon Sep 17 00:00:00 2001
From: Mehmet Deniz Birlikci
Date: Sun, 18 Feb 2024 16:22:12 -0500
Subject: [PATCH] safetensors support when loading from hf_hub --> check .bin
 file first, .safetensors only if .bin not found

---
 requirements.txt            |  1 +
 src/open_clip/factory.py    |  5 +++++
 src/open_clip/pretrained.py | 37 +++++++++++++++++++++++++++++++------
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 3ff4f5d0a..75cca400a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ regex
 ftfy
 tqdm
 huggingface_hub
+safetensors
 sentencepiece
 protobuf
 timm
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index cf62d5a1c..ae32a2177 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -144,6 +144,11 @@ def load_checkpoint(model, checkpoint_path, strict=True):
         from .big_vision import load_big_vision_weights
         load_big_vision_weights(model, checkpoint_path)
         return {}
+    elif Path(checkpoint_path).suffix in ('.safetensors',):
+        from safetensors.torch import load_model
+        # Forward `strict` so a caller's strict=False also applies to safetensors files.
+        load_model(model, checkpoint_path, strict=strict)
+        return {}
 
     state_dict = load_state_dict(checkpoint_path)
     # detect old format and make compatible with new format
diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py
index e7cd74fe1..b525c3515 100644
--- a/src/open_clip/pretrained.py
+++ b/src/open_clip/pretrained.py
@@ -543,14 +543,39 @@ def has_hf_hub(necessary=False):
 
 
 def download_pretrained_from_hf(
-        model_id: str,
-        filename: str = 'open_clip_pytorch_model.bin',
-        revision=None,
-        cache_dir: Union[str, None] = None,
+    model_id: str,
+    filename: Union[str, None] = None,
+    revision=None,
+    cache_dir: Union[str, None] = None,
 ):
+    """Download a checkpoint from the Hugging Face Hub and return its cached path.
+
+    When ``filename`` is None, 'open_clip_pytorch_model.bin' is tried first and
+    'open_clip_pytorch_model.safetensors' only if the .bin file is missing.
+    Errors other than a missing file (auth, network, unknown repo) propagate
+    unchanged; FileNotFoundError is raised only when no candidate file exists.
+    """
     has_hf_hub(True)
-    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
-    return cached_file
+    # Local import keeps huggingface_hub.utils out of module import time.
+    from huggingface_hub.utils import EntryNotFoundError
+
+    # Prefer the .bin checkpoint; fall back to safetensors only if it is absent.
+    if filename is None:
+        filenames_to_try = ['open_clip_pytorch_model.bin', 'open_clip_pytorch_model.safetensors']
+    else:
+        filenames_to_try = [filename]
+
+    last_exception = None
+    for fname in filenames_to_try:
+        try:
+            return hf_hub_download(model_id, fname, revision=revision, cache_dir=cache_dir)
+        except EntryNotFoundError as e:
+            # Only a missing file justifies trying the next candidate.
+            last_exception = e
+
+    raise FileNotFoundError(
+        f"Failed to download any files for {model_id}. Last error: {last_exception}"
+    ) from last_exception
 
 
 def download_pretrained(