From 84a659d3e595afae66c810efec422ef200d24f3a Mon Sep 17 00:00:00 2001
From: Sora
Date: Sun, 8 Oct 2023 18:16:46 +0800
Subject: [PATCH 1/2] Extract get_text() and infer() from webui.py.

---
 infer_utils.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++
 server.py      | 73 ++++-----------------------------------------
 webui.py       | 77 +++---------------------------------------------
 3 files changed, 89 insertions(+), 141 deletions(-)
 create mode 100644 infer_utils.py

diff --git a/infer_utils.py b/infer_utils.py
new file mode 100644
index 000000000..3a4af4e02
--- /dev/null
+++ b/infer_utils.py
@@ -0,0 +1,80 @@
+"""
+@Author: Kasugano Sora
+@Github: https://github.com/jiangyuxiaoxiao
+@Date: 2023/10/08-18:01
+@Desc:
+@Ver : 1.0.0
+"""
+import torch
+import commons
+from text import cleaned_text_to_sequence, get_bert
+from text.cleaner import clean_text
+
+
+def get_text(text, language_str, hps, device):
+    norm_text, phone, tone, word2ph = clean_text(text, language_str)
+    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
+
+    if hps.data.add_blank:
+        phone = commons.intersperse(phone, 0)
+        tone = commons.intersperse(tone, 0)
+        language = commons.intersperse(language, 0)
+        for i in range(len(word2ph)):
+            word2ph[i] = word2ph[i] * 2
+        word2ph[0] += 1
+    bert = get_bert(norm_text, word2ph, language_str, device)
+    del word2ph
+    assert bert.shape[-1] == len(phone), phone
+
+    if language_str == "ZH":
+        bert = bert
+        ja_bert = torch.zeros(768, len(phone))
+    elif language_str == "JP":
+        ja_bert = bert
+        bert = torch.zeros(1024, len(phone))
+    else:
+        bert = torch.zeros(1024, len(phone))
+        ja_bert = torch.zeros(768, len(phone))
+
+    assert bert.shape[-1] == len(
+        phone
+    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
+
+    phone = torch.LongTensor(phone)
+    tone = torch.LongTensor(tone)
+    language = torch.LongTensor(language)
+    return bert, ja_bert, phone, tone, language
+
+
+def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language, hps, net_g, device):
+    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps, device)
+    with torch.no_grad():
+        x_tst = phones.to(device).unsqueeze(0)
+        tones = tones.to(device).unsqueeze(0)
+        lang_ids = lang_ids.to(device).unsqueeze(0)
+        bert = bert.to(device).unsqueeze(0)
+        ja_bert = ja_bert.to(device).unsqueeze(0)
+        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+        del phones
+        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+        audio = (
+            net_g.infer(
+                x_tst,
+                x_tst_lengths,
+                speakers,
+                tones,
+                lang_ids,
+                bert,
+                ja_bert,
+                sdp_ratio=sdp_ratio,
+                noise_scale=noise_scale,
+                noise_scale_w=noise_scale_w,
+                length_scale=length_scale,
+            )[0][0, 0]
+            .data.cpu()
+            .float()
+            .numpy()
+        )
+        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
+        torch.cuda.empty_cache()
+        return audio
diff --git a/server.py b/server.py
index 656f77b8a..085320d01 100644
--- a/server.py
+++ b/server.py
@@ -3,12 +3,10 @@
 import torch
 from av import open as avopen
 
-import commons
 import utils
 from models import SynthesizerTrn
 from text.symbols import symbols
-from text import cleaned_text_to_sequence, get_bert
-from text.cleaner import clean_text
+from infer_utils import infer
 from scipy.io import wavfile
 
 # Flask Init
@@ -16,70 +14,6 @@
 app.config["JSON_AS_ASCII"] = False
 
 
-def get_text(text, language_str, hps):
-    norm_text, phone, tone, word2ph = clean_text(text, language_str)
-    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
-
-    if hps.data.add_blank:
-        phone = commons.intersperse(phone, 0)
-        tone = commons.intersperse(tone, 0)
-        language = commons.intersperse(language, 0)
-        for i in range(len(word2ph)):
-            word2ph[i] = word2ph[i] * 2
-        word2ph[0] += 1
-    bert = get_bert(norm_text, word2ph, language_str, dev)
-    del word2ph
-    assert bert.shape[-1] == len(phone), phone
-
-    if language_str == "ZH":
-        bert = bert
-        ja_bert = torch.zeros(768, len(phone))
-    elif language_str == "JA":
-        ja_bert = bert
-        bert = torch.zeros(1024, len(phone))
-    else:
-        bert = torch.zeros(1024, len(phone))
-        ja_bert = torch.zeros(768, len(phone))
-    assert bert.shape[-1] == len(
-        phone
-    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
-    phone = torch.LongTensor(phone)
-    tone = torch.LongTensor(tone)
-    language = torch.LongTensor(language)
-    return bert, ja_bert, phone, tone, language
-
-
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
-    with torch.no_grad():
-        x_tst = phones.to(dev).unsqueeze(0)
-        tones = tones.to(dev).unsqueeze(0)
-        lang_ids = lang_ids.to(dev).unsqueeze(0)
-        bert = bert.to(dev).unsqueeze(0)
-        ja_bert = ja_bert.to(dev).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(dev)
-        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(dev)
-        audio = (
-            net_g.infer(
-                x_tst,
-                x_tst_lengths,
-                speakers,
-                tones,
-                lang_ids,
-                bert,
-                ja_bert,
-                sdp_ratio=sdp_ratio,
-                noise_scale=noise_scale,
-                noise_scale_w=noise_scale_w,
-                length_scale=length_scale,
-            )[0][0, 0]
-            .data.cpu()
-            .float()
-            .numpy()
-        )
-        return audio
-
-
 def replace_punctuation(text, i=2):
     punctuation = ",。?!"
     for char in punctuation:
@@ -141,7 +75,7 @@ def main():
             return "Missing Parameter"
         if fmt not in ("mp3", "wav", "ogg"):
             return "Invalid Format"
-        if language not in ("JA", "ZH"):
+        if language not in ("JP", "ZH"):
             return "Invalid language"
     except:
         return "Invalid Parameter"
@@ -155,6 +89,9 @@ def main():
         length_scale=length,
         sid=speaker,
         language=language,
+        hps=hps,
+        net_g=net_g,
+        device=dev
     )
 
     with BytesIO() as wav:
diff --git a/webui.py b/webui.py
index d9dbd2ca7..19fe0db82 100644
--- a/webui.py
+++ b/webui.py
@@ -16,12 +16,10 @@
 import torch
 import argparse
 
-import commons
 import utils
 from models import SynthesizerTrn
 from text.symbols import symbols
-from text import cleaned_text_to_sequence, get_bert
-from text.cleaner import clean_text
+from infer_utils import infer
 import gradio as gr
 import webbrowser
 import numpy as np
@@ -35,76 +33,6 @@
 device = "cuda"
 
 
-def get_text(text, language_str, hps):
-    norm_text, phone, tone, word2ph = clean_text(text, language_str)
-    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
-
-    if hps.data.add_blank:
-        phone = commons.intersperse(phone, 0)
-        tone = commons.intersperse(tone, 0)
-        language = commons.intersperse(language, 0)
-        for i in range(len(word2ph)):
-            word2ph[i] = word2ph[i] * 2
-        word2ph[0] += 1
-    bert = get_bert(norm_text, word2ph, language_str, device)
-    del word2ph
-    assert bert.shape[-1] == len(phone), phone
-
-    if language_str == "ZH":
-        bert = bert
-        ja_bert = torch.zeros(768, len(phone))
-    elif language_str == "JP":
-        ja_bert = bert
-        bert = torch.zeros(1024, len(phone))
-    else:
-        bert = torch.zeros(1024, len(phone))
-        ja_bert = torch.zeros(768, len(phone))
-
-    assert bert.shape[-1] == len(
-        phone
-    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
-
-    phone = torch.LongTensor(phone)
-    tone = torch.LongTensor(tone)
-    language = torch.LongTensor(language)
-    return bert, ja_bert, phone, tone, language
-
-
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
-    global net_g
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
-    with torch.no_grad():
-        x_tst = phones.to(device).unsqueeze(0)
-        tones = tones.to(device).unsqueeze(0)
-        lang_ids = lang_ids.to(device).unsqueeze(0)
-        bert = bert.to(device).unsqueeze(0)
-        ja_bert = ja_bert.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        del phones
-        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
-        audio = (
-            net_g.infer(
-                x_tst,
-                x_tst_lengths,
-                speakers,
-                tones,
-                lang_ids,
-                bert,
-                ja_bert,
-                sdp_ratio=sdp_ratio,
-                noise_scale=noise_scale,
-                noise_scale_w=noise_scale_w,
-                length_scale=length_scale,
-            )[0][0, 0]
-            .data.cpu()
-            .float()
-            .numpy()
-        )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
-        torch.cuda.empty_cache()
-        return audio
-
-
 def tts_fn(
     text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language
 ):
@@ -120,6 +48,9 @@ def tts_fn(
                 length_scale=length_scale,
                 sid=speaker,
                 language=language,
+                hps=hps,
+                net_g=net_g,
+                device=device
             )
             audio_list.append(audio)
             silence = np.zeros(hps.data.sampling_rate)  # 生成1秒的静音

From 4d5a16c26f1084382550e568805be61aa653dce2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 8 Oct 2023 10:19:19 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 infer_utils.py | 13 ++++++++++++-
 server.py      |  2 +-
 webui.py       |  2 +-
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/infer_utils.py b/infer_utils.py
index 3a4af4e02..25af7715c 100644
--- a/infer_utils.py
+++ b/infer_utils.py
@@ -46,7 +46,18 @@ def get_text(text, language_str, hps, device):
     return bert, ja_bert, phone, tone, language
 
 
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language, hps, net_g, device):
+def infer(
+    text,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    sid,
+    language,
+    hps,
+    net_g,
+    device,
+):
     bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps, device)
     with torch.no_grad():
         x_tst = phones.to(device).unsqueeze(0)
diff --git a/server.py b/server.py
index 085320d01..9e562d773 100644
--- a/server.py
+++ b/server.py
@@ -91,7 +91,7 @@ def main():
         language=language,
         hps=hps,
         net_g=net_g,
-        device=dev
+        device=dev,
     )
 
     with BytesIO() as wav:
diff --git a/webui.py b/webui.py
index 19fe0db82..d0f5f40e6 100644
--- a/webui.py
+++ b/webui.py
@@ -50,7 +50,7 @@ def tts_fn(
                 language=language,
                 hps=hps,
                 net_g=net_g,
-                device=device
+                device=device,
             )
             audio_list.append(audio)
             silence = np.zeros(hps.data.sampling_rate)  # 生成1秒的静音
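
Usage sketch (not part of the patch): with hps, net_g, and device now passed in explicitly, the extracted infer() can be driven from a standalone script instead of only from webui.py or server.py. The minimal example below is a sketch under assumptions: the config path, checkpoint path, and speaker name are hypothetical placeholders, and the SynthesizerTrn construction mirrors the existing setup in webui.py and server.py.

import torch

import utils
from infer_utils import infer
from models import SynthesizerTrn
from text.symbols import symbols

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical config path; point this at your own model config.
hps = utils.get_hparams_from_file("configs/config.json")

# Build the synthesizer the same way webui.py/server.py already do.
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
).to(device)
_ = net_g.eval()

# Hypothetical checkpoint path; substitute a real generator checkpoint.
utils.load_checkpoint("logs/G_0.pth", net_g, None, skip_optimizer=True)

audio = infer(
    text="こんにちは。",
    sdp_ratio=0.2,        # typical UI defaults for these four scales
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="speaker_0",      # placeholder; must be a key in hps.data.spk2id
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
)
# audio is a float32 NumPy array sampled at hps.data.sampling_rate.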