diff --git a/fish_diffusion/modules/feature_extractors/content_vec.py b/fish_diffusion/modules/feature_extractors/content_vec.py
index 3ede324f..0cabf3f0 100644
--- a/fish_diffusion/modules/feature_extractors/content_vec.py
+++ b/fish_diffusion/modules/feature_extractors/content_vec.py
@@ -1,6 +1,8 @@
 import logging
+import os
 
 import torch
+from loguru import logger
 
 from .base import BaseFeatureExtractor
 from .builder import FEATURE_EXTRACTORS
@@ -19,6 +21,20 @@ def __init__(
     ):
         super().__init__()
 
+        if os.path.exists(checkpoint_path) is False:
+            logger.error(f"Checkpoint {checkpoint_path} does not exist")
+            logger.error(
+                f"If you are trying to use the pretrained ContentVec, you can either:"
+            )
+            logger.error(
+                f"1. Download the ContentVec from https://github.com/fishaudio/fish-diffusion/releases/tag/v1.12"
+            )
+            logger.error(
+                f"2. Run `python tools/download_nsf_hifigan.py --content-vec` to download the ContentVec model"
+            )
+
+            raise FileNotFoundError(f"Checkpoint {checkpoint_path} does not exist")
+
         models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
             [checkpoint_path], suffix=""
         )
diff --git a/tools/diffusion/clean_speaker_embeddings.py b/tools/diffusion/clean_speaker_embeddings.py
index 29965e7f..9c73b098 100644
--- a/tools/diffusion/clean_speaker_embeddings.py
+++ b/tools/diffusion/clean_speaker_embeddings.py
@@ -1,9 +1,10 @@
 import torch
 
 data = torch.load(
-    "checkpoints/epoch=198-step=260000-valid_loss=0.18.ckpt", map_location="cpu"
+    "logs/DiffSVC/9ddsi2gk/checkpoints/epoch=88-step=300000-valid_loss=0.08.ckpt",
+    map_location="cpu",
 )
 
 del data["state_dict"]["model.speaker_encoder.embedding.weight"]
 
-torch.save(data, "checkpoints/epoch=198-step=260000-valid_loss=0.18-fixed.ckpt")
+torch.save(data, "checkpoints/content-vec-pretrained-v1.ckpt")
diff --git a/tools/diffusion/train.py b/tools/diffusion/train.py
index e2e78dea..c71768bf 100644
--- a/tools/diffusion/train.py
+++ b/tools/diffusion/train.py
@@ -59,7 +59,7 @@
         i.replace(".naive_noise_predictor.", ".") for i in missing_keys
     )
 
-    assert len(unexpected_keys) == 0
+    assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
 
     if args.only_train_speaker_embeddings:
         for name, param in model.named_parameters():
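
For reference, a minimal standalone sketch of what a caller sees once the new guard is in place. The helper name load_extractor and the example checkpoint path below are hypothetical, used only for illustration; the actual guard lives in the feature extractor's __init__ as shown in the diff above.

    import os

    from loguru import logger


    def load_extractor(checkpoint_path: str):
        # Same pattern as the guard added to content_vec.py: point the user at the
        # release download / helper script, then fail fast with FileNotFoundError.
        if not os.path.exists(checkpoint_path):
            logger.error(f"Checkpoint {checkpoint_path} does not exist")
            logger.error(
                "Run `python tools/download_nsf_hifigan.py --content-vec` to download the ContentVec model"
            )
            raise FileNotFoundError(f"Checkpoint {checkpoint_path} does not exist")
        # ... model loading (checkpoint_utils.load_model_ensemble_and_task) would follow here


    try:
        load_extractor("checkpoints/missing-content-vec.pt")
    except FileNotFoundError as exc:
        print(exc)  # Checkpoint checkpoints/missing-content-vec.pt does not exist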