diff --git a/ChatTTS/core.py b/ChatTTS/core.py index a05660421..736987e5b 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -67,9 +67,13 @@ def download_models( source: Literal["huggingface", "local", "custom"] = "local", force_redownload=False, custom_path: Optional[torch.serialization.FILE_LIKE] = None, + cache_dir: Optional[str] = None, + local_dir: Optional[str] = None, ) -> Optional[str]: if source == "local": - download_path = custom_path if custom_path is not None else os.getcwd() + download_path = ( + local_dir if local_dir else (cache_dir if cache_dir else os.getcwd()) + ) if ( not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload @@ -83,43 +87,98 @@ def download_models( "download to local path %s failed.", download_path ) return None + elif source == "huggingface": try: - download_path = ( - get_latest_modified_file( - os.path.join( - os.getenv( - "HF_HOME", os.path.expanduser("~/.cache/huggingface") - ), - "hub/models--2Noise--ChatTTS/snapshots", - ) - ) - if custom_path is None - else get_latest_modified_file( - os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots") + if local_dir: + download_path = snapshot_download( + repo_id="2Noise/ChatTTS", + allow_patterns=[ + "*.yaml", + "*.json", + "*.safetensors", + "spk_stat.pt", + "tokenizer.pt", + ], + local_dir=local_dir, + force_download=force_redownload, ) - ) - except: - download_path = None - if download_path is None or force_redownload: - self.logger.log( - logging.INFO, - f"download from HF: https://huggingface.co/2Noise/ChatTTS", - ) - try: + if not check_all_assets( + Path(download_path), self.sha256_map, update=False + ): + self.logger.error("Model verification failed") + return None + elif cache_dir: download_path = snapshot_download( repo_id="2Noise/ChatTTS", - allow_patterns=["*.yaml", "*.json", "*.safetensors"], - cache_dir=custom_path, + allow_patterns=[ + "*.yaml", + "*.json", + "*.safetensors", + "spk_stat.pt", + "tokenizer.pt", + ], + cache_dir=cache_dir, force_download=force_redownload, ) - except: - download_path = None + if not check_all_assets( + Path(download_path), self.sha256_map, update=False + ): + self.logger.error("Model verification failed") + return None else: - self.logger.log( - logging.INFO, - f"load latest snapshot from cache: {download_path}", - ) + try: + download_path = ( + get_latest_modified_file( + os.path.join( + os.getenv( + "HF_HOME", + os.path.expanduser("~/.cache/huggingface"), + ), + "hub/models--2Noise--ChatTTS/snapshots", + ) + ) + if custom_path is None + else get_latest_modified_file( + os.path.join( + custom_path, "models--2Noise--ChatTTS/snapshots" + ) + ) + ) + except: + download_path = None + if download_path is None or force_redownload: + self.logger.log( + logging.INFO, + f"download from HF: https://huggingface.co/2Noise/ChatTTS", + ) + try: + download_path = snapshot_download( + repo_id="2Noise/ChatTTS", + allow_patterns=[ + "*.yaml", + "*.json", + "*.safetensors", + "spk_stat.pt", + "tokenizer.pt", + ], + ) + if not check_all_assets( + Path(download_path), self.sha256_map, update=False + ): + self.logger.error("Model verification failed") + return None + except: + download_path = None + else: + self.logger.log( + logging.INFO, + f"load latest snapshot from cache: {download_path}", + ) + except Exception as e: + self.logger.error(f"Failed to download models: {str(e)}") + download_path = None + elif source == "custom": self.logger.log(logging.INFO, f"try to load from local: {custom_path}") if not check_all_assets(Path(custom_path), self.sha256_map, update=False): @@ -144,8 +203,12 @@ def load( use_flash_attn=False, use_vllm=False, experimental: bool = False, + cache_dir: Optional[str] = None, + local_dir: Optional[str] = None, ) -> bool: - download_path = self.download_models(source, force_redownload, custom_path) + download_path = self.download_models( + source, force_redownload, custom_path, cache_dir, local_dir + ) if download_path is None: return False return self._load(