Skip to content

Commit

Permalink
fix(xtts): clearer error message when file given to checkpoint_dir
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Dec 2, 2024
1 parent 98a372b commit ce20253
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import torch
Expand All @@ -10,6 +11,7 @@
from coqpit import Coqpit
from trainer.io import load_fsspec

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
Expand Down Expand Up @@ -719,14 +721,14 @@ def get_compatible_checkpoint_state_dict(self, model_path):

def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
vocab_path=None,
eval=True,
strict=True,
use_deepspeed=False,
speaker_file_path=None,
config: XttsConfig,
checkpoint_dir: Optional[str] = None,
checkpoint_path: Optional[str] = None,
vocab_path: Optional[str] = None,
eval: bool = True,
strict: bool = True,
use_deepspeed: bool = False,
speaker_file_path: Optional[str] = None,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
Expand All @@ -742,7 +744,9 @@ def load_checkpoint(
Returns:
None
"""

if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
raise ValueError(msg)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
if vocab_path is None:
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():
Expand Down

0 comments on commit ce20253

Please sign in to comment.