Skip to content

Commit

Permalink
Add instructions for downloading ContentVec
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Mar 5, 2023
1 parent d75ea7c commit 3201bd6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
16 changes: 16 additions & 0 deletions fish_diffusion/modules/feature_extractors/content_vec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os

import torch
from loguru import logger

from .base import BaseFeatureExtractor
from .builder import FEATURE_EXTRACTORS
Expand All @@ -19,6 +21,20 @@ def __init__(
):
super().__init__()

if os.path.exists(checkpoint_path) is False:
logger.error(f"Checkpoint {checkpoint_path} does not exist")
logger.error(
f"If you are trying to use the pretrained ContentVec, you can either:"
)
logger.error(
f"1. Download the ContentVec from https://github.com/fishaudio/fish-diffusion/releases/tag/v1.12"
)
logger.error(
f"2. Run `python tools/download_nsf_hifigan.py --content-vec` to download the ContentVec model"
)

raise FileNotFoundError(f"Checkpoint {checkpoint_path} does not exist")

models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], suffix=""
)
Expand Down
5 changes: 3 additions & 2 deletions tools/diffusion/clean_speaker_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch

data = torch.load(
"checkpoints/epoch=198-step=260000-valid_loss=0.18.ckpt", map_location="cpu"
"logs/DiffSVC/9ddsi2gk/checkpoints/epoch=88-step=300000-valid_loss=0.08.ckpt",
map_location="cpu",
)

del data["state_dict"]["model.speaker_encoder.embedding.weight"]

torch.save(data, "checkpoints/epoch=198-step=260000-valid_loss=0.18-fixed.ckpt")
torch.save(data, "checkpoints/content-vec-pretrained-v1.ckpt")
2 changes: 1 addition & 1 deletion tools/diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
i.replace(".naive_noise_predictor.", ".") for i in missing_keys
)

assert len(unexpected_keys) == 0
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"

if args.only_train_speaker_embeddings:
for name, param in model.named_parameters():
Expand Down

0 comments on commit 3201bd6

Please sign in to comment.