forked from snakers4/silero-models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhubconf.py
executable file
·149 lines (131 loc) · 6.12 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Package names that torch.hub expects to be installed before it loads the
# entry points defined in this hubconf.py.
dependencies = ['torch']
import os
import torch
def silero_stt(language='en',
               version='latest',
               jit_model='jit',
               **kwargs):
    """ Silero Speech-To-Text Model(s)

    language (str): language of the model, now available are ['en', 'de', 'es']
        (the authoritative list is the set of keys under `stt_models` in the
        models list file)
    version (str): key of the model version listed under the chosen language
    jit_model (str): key of the model URL entry listed under the chosen version
    **kwargs: forwarded unchanged to `init_jit_model`

    Returns a model, decoder object and a set of utils
    (read_batch, split_into_batches, read_audio, prepare_model_input)

    Raises AssertionError if the models list file is missing or if `language`,
    `version` or `jit_model` is not listed in it.

    Please see https://github.com/snakers4/silero-models for usage examples
    """
    from omegaconf import OmegaConf
    from utils import (init_jit_model,
                       read_audio,
                       read_batch,
                       split_into_batches,
                       prepare_model_input)

    # Prefer the models.yml located next to this file; otherwise fall back to
    # a copy in the current working directory, downloading it when absent.
    models_list_file = os.path.join(os.path.dirname(__file__), "models.yml")
    if not os.path.exists(models_list_file):
        models_list_file = 'latest_silero_models.yml'
    if not os.path.exists(models_list_file):
        torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
                                       'latest_silero_models.yml',
                                       progress=False)
    assert os.path.exists(models_list_file), f'Models list file not found: {models_list_file}'

    models = OmegaConf.load(models_list_file)

    available_languages = list(models.stt_models.keys())
    assert language in available_languages, f'Language not in the supported list {available_languages}'

    # Validate each level of the lookup explicitly: a chained `.get()` on a
    # missing key would otherwise fail with an AttributeError on None (or pass
    # a None URL to the loader) without saying which argument was wrong.
    language_conf = models.stt_models.get(language)
    available_versions = list(language_conf.keys())
    assert version in available_versions, f'Version not in the supported list {available_versions}'

    version_conf = language_conf.get(version)
    available_models = list(version_conf.keys())
    assert jit_model in available_models, f'Model type not in the supported list {available_models}'

    model, decoder = init_jit_model(model_url=version_conf.get(jit_model),
                                    **kwargs)
    utils = (read_batch,
             split_into_batches,
             read_audio,
             prepare_model_input)

    return model, decoder, utils
def silero_tts(language='en',
               speaker='kseniya_16khz',
               **kwargs):
    """ Silero Text-To-Speech Models

    language (str): language of the model, now available are ['ru', 'en', 'de', 'es', 'fr']
        (the authoritative list is the set of keys under `tts_models` in the
        models list file)
    speaker (str): speaker key; it must be listed under `language` in the
        models list file
    **kwargs: accepted for call compatibility; not used by this function

    Returns, depending on the speaker:
      * speaker name contains '_v2' and equals 'multi_v2':
        (model, available speakers)
      * any other speaker name containing '_v2':
        (model, example text)
      * otherwise:
        (model, symbols, sample_rate, example_text, apply_tts)

    Raises AssertionError if the models list file is missing, or if the
    language / speaker combination is not listed in it.

    NOTE(review): a call with no arguments only passes validation if the
    default speaker is listed under the default language - confirm against
    models.yml.

    Please see https://github.com/snakers4/silero-models for usage examples
    """
    from omegaconf import OmegaConf
    from tts_utils import apply_tts
    from tts_utils import init_jit_model as init_jit_model_tts

    # Prefer the models.yml located next to this file; otherwise fall back to
    # a copy in the current working directory, downloading it when absent.
    models_list_file = os.path.join(os.path.dirname(__file__), "models.yml")
    if not os.path.exists(models_list_file):
        models_list_file = 'latest_silero_models.yml'
    if not os.path.exists(models_list_file):
        torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
                                       'latest_silero_models.yml',
                                       progress=False)
    assert os.path.exists(models_list_file)

    models = OmegaConf.load(models_list_file)

    available_languages = list(models.tts_models.keys())
    assert language in available_languages, f'Language not in the supported list {available_languages}'

    speakers_by_language = {lang: list(models.tts_models.get(lang).keys())
                            for lang in available_languages}
    available_speakers = [name
                          for names in speakers_by_language.values()
                          for name in names]
    assert speaker in available_speakers, f'Speaker not in the supported list {available_speakers}'

    # Collect every language that lists this speaker instead of keeping only
    # the last one seen, so a speaker name shared between languages is
    # accepted for each of them.
    speaker_languages = [lang for lang in available_languages
                         if speaker in speakers_by_language[lang]]
    expected = "' or '".join(speaker_languages)
    assert language in speaker_languages, f"Incorrect language '{language}' for this speaker, please specify '{expected}'"

    model_conf = models.tts_models[language][speaker].latest
    if '_v2' in speaker:
        from torch import package

        # '_v2' speakers are distributed as torch.package archives, cached in
        # a "model" directory next to this file.
        model_url = model_conf.package
        model_dir = os.path.join(os.path.dirname(__file__), "model")
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, os.path.basename(model_url))

        if not os.path.isfile(model_path):
            torch.hub.download_url_to_file(model_url,
                                           model_path,
                                           progress=True)

        # NOTE(review): this unpickles a downloaded archive; the URL in the
        # models list file has to be trusted.
        imp = package.PackageImporter(model_path)
        model = imp.load_pickle("tts_models", "model")
        if speaker == 'multi_v2':
            avail_speakers = model_conf.speakers
            return model, avail_speakers
        else:
            example_text = model_conf.example
            return model, example_text
    else:
        model = init_jit_model_tts(model_conf.jit)
        symbols = model_conf.tokenset
        example_text = model_conf.example
        sample_rate = model_conf.sample_rate
        return model, symbols, sample_rate, example_text, apply_tts
def silero_te():
    """ Silero Texts Enhancing Models
    Current model supports the following languages: ['en', 'de', 'ru', 'es']

    Returns a 5-tuple:
      model         - object unpickled from the downloaded package
      example_texts - the model's `examples` attribute
      languages     - `languages` entry of the latest `te_models` config
      punct         - `punct` entry of the latest `te_models` config
      apply_te      - helper calling `model.enhance_text(text, lan)`

    Please see https://github.com/snakers4/silero-models for usage examples
    """
    import yaml
    from torch import package

    here = os.path.dirname(__file__)

    # Locate the models list: next to this file first, then a copy in the
    # current working directory, which is downloaded when it is absent.
    config_path = os.path.join(here, "models.yml")
    if not os.path.exists(config_path):
        config_path = 'latest_silero_models.yml'
        if not os.path.exists(config_path):
            torch.hub.download_url_to_file(
                'https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
                'latest_silero_models.yml',
                progress=False)
    assert os.path.exists(config_path)

    with open(config_path, 'r', encoding='utf8') as handle:
        catalog = yaml.load(handle, Loader=yaml.SafeLoader)
    te_conf = catalog.get('te_models').get('latest')
    package_url = te_conf.get('package')

    # The package archive is cached in a "model" directory next to this file
    # and fetched only when it is not already there.
    cache_dir = os.path.join(here, "model")
    os.makedirs(cache_dir, exist_ok=True)
    package_path = os.path.join(cache_dir, os.path.basename(package_url))
    if not os.path.isfile(package_path):
        torch.hub.download_url_to_file(package_url, package_path, progress=True)

    model = package.PackageImporter(package_path).load_pickle("te_model", "model")

    def apply_te(text, lan='en'):
        return model.enhance_text(text, lan)

    return (model,
            model.examples,
            te_conf.get('languages'),
            te_conf.get('punct'),
            apply_te)