diff --git a/requirements.txt b/requirements.txt index 3ff4f5d0a..75cca400a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ regex ftfy tqdm huggingface_hub +safetensors sentencepiece protobuf timm diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index cf62d5a1c..ae32a2177 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -144,6 +144,10 @@ def load_checkpoint(model, checkpoint_path, strict=True): from .big_vision import load_big_vision_weights load_big_vision_weights(model, checkpoint_path) return {} + elif Path(checkpoint_path).suffix in ('.safetensors',): + from safetensors.torch import load_model + load_model(model, checkpoint_path) + return {} state_dict = load_state_dict(checkpoint_path) # detect old format and make compatible with new format diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index e7cd74fe1..b525c3515 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -543,14 +543,28 @@ def has_hf_hub(necessary=False): def download_pretrained_from_hf( - model_id: str, - filename: str = 'open_clip_pytorch_model.bin', - revision=None, - cache_dir: Union[str, None] = None, + model_id: str, + filename: Union[str, None] = None, + revision=None, + cache_dir: Union[str, None] = None, ): has_hf_hub(True) - cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) - return cached_file + + # List of filenames to try downloading --> try safetensors if .bin file is missing + filenames_to_try = ['open_clip_pytorch_model.bin', 'open_clip_pytorch_model.safetensors'] if filename is None else [filename] + + last_exception = None # Variable to store the last exception + for fname in filenames_to_try: + try: + # Attempt to download the file + cached_file = hf_hub_download(model_id, fname, revision=revision, cache_dir=cache_dir) + return cached_file # Return the path to the downloaded file if successful + except Exception as e: + last_exception = e # Store the last exception encountered + continue # Try the next file + + # If the loop completes without returning, raise the last encountered exception + raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {last_exception}") from last_exception def download_pretrained(