Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add local dir parameter #887

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 94 additions & 31 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def download_models(
source: Literal["huggingface", "local", "custom"] = "local",
force_redownload=False,
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
cache_dir: Optional[str] = None,
local_dir: Optional[str] = None,
) -> Optional[str]:
if source == "local":
download_path = custom_path if custom_path is not None else os.getcwd()
download_path = (
local_dir if local_dir else (cache_dir if cache_dir else os.getcwd())
)
if (
not check_all_assets(Path(download_path), self.sha256_map, update=True)
or force_redownload
Expand All @@ -83,43 +87,98 @@ def download_models(
"download to local path %s failed.", download_path
)
return None

elif source == "huggingface":
try:
download_path = (
get_latest_modified_file(
os.path.join(
os.getenv(
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
),
"hub/models--2Noise--ChatTTS/snapshots",
)
)
if custom_path is None
else get_latest_modified_file(
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
if local_dir:
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=[
"*.yaml",
"*.json",
"*.safetensors",
"spk_stat.pt",
"tokenizer.pt",
],
local_dir=local_dir,
force_download=force_redownload,
)
)
except:
download_path = None
if download_path is None or force_redownload:
self.logger.log(
logging.INFO,
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
)
try:
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
):
self.logger.error("Model verification failed")
return None
elif cache_dir:
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
cache_dir=custom_path,
allow_patterns=[
"*.yaml",
"*.json",
"*.safetensors",
"spk_stat.pt",
"tokenizer.pt",
],
cache_dir=cache_dir,
force_download=force_redownload,
)
except:
download_path = None
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
):
self.logger.error("Model verification failed")
return None
else:
self.logger.log(
logging.INFO,
f"load latest snapshot from cache: {download_path}",
)
try:
download_path = (
get_latest_modified_file(
os.path.join(
os.getenv(
"HF_HOME",
os.path.expanduser("~/.cache/huggingface"),
),
"hub/models--2Noise--ChatTTS/snapshots",
)
)
if custom_path is None
else get_latest_modified_file(
os.path.join(
custom_path, "models--2Noise--ChatTTS/snapshots"
)
)
)
except:
download_path = None
if download_path is None or force_redownload:
self.logger.log(
logging.INFO,
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
)
try:
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=[
"*.yaml",
"*.json",
"*.safetensors",
"spk_stat.pt",
"tokenizer.pt",
],
)
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
):
self.logger.error("Model verification failed")
return None
except:
download_path = None
else:
self.logger.log(
logging.INFO,
f"load latest snapshot from cache: {download_path}",
)
except Exception as e:
self.logger.error(f"Failed to download models: {str(e)}")
download_path = None

elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
Expand All @@ -144,8 +203,12 @@ def load(
use_flash_attn=False,
use_vllm=False,
experimental: bool = False,
cache_dir: Optional[str] = None,
local_dir: Optional[str] = None,
) -> bool:
download_path = self.download_models(source, force_redownload, custom_path)
download_path = self.download_models(
source, force_redownload, custom_path, cache_dir, local_dir
)
if download_path is None:
return False
return self._load(
Expand Down