forked from yeyupiaoling/MASR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer_server.py
111 lines (98 loc) · 5.47 KB
/
infer_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import functools
import json
import os
import sys
import time
import uuid
from datetime import datetime

from flask import request, Flask, render_template
from flask_cors import CORS

from masr import SUPPORT_MODEL
from masr.predict import Predictor
from masr.utils.audio_vad import crop_audio_vad
from masr.utils.utils import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Command-line options, registered in the order listed below.
# Each entry: (name, type, default, help text, extra keyword arguments for add_arguments).
for _opt_name, _opt_type, _opt_default, _opt_help, _opt_extra in (
    ('use_model', str, 'deepspeech2', "所使用的模型", {'choices': SUPPORT_MODEL}),
    ("host", str, "0.0.0.0", "监听主机的IP地址", {}),
    ("port", int, 5000, "服务所使用的端口号", {}),
    ("save_path", str, 'dataset/upload/', "上传音频文件的保存目录", {}),
    ('use_gpu', bool, True, "是否使用GPU预测", {}),
    ('to_an', bool, False, "是否转为阿拉伯数字", {}),
    ('use_pun', bool, False, "是否给识别结果加标点符号", {}),
    # Beam-search decoder tuning.
    ('beam_size', int, 300, "集束搜索解码相关参数,搜索大小,范围:[5, 500]", {}),
    ('alpha', float, 2.2, "集束搜索解码相关参数,LM系数", {}),
    ('beta', float, 4.3, "集束搜索解码相关参数,WC系数", {}),
    ('cutoff_prob', float, 0.99, "集束搜索解码相关参数,剪枝的概率", {}),
    ('cutoff_top_n', int, 40, "集束搜索解码相关参数,剪枝的最大值", {}),
    # Model / resource locations. model_path is later formatted with (use_model, feature_method).
    ('vocab_path', str, 'dataset/vocabulary.txt', "数据集的词汇表文件路径", {}),
    ('model_path', str, 'models/{}_{}/inference.pt', "导出的预测模型文件路径", {}),
    ('pun_model_dir', str, 'models/pun_models/', "加标点符号的模型文件夹路径", {}),
    ('lang_model_path', str, 'lm/zh_giga.no_cna_cmn.prune01244.klm', "集束搜索解码相关参数,语言模型文件路径", {}),
    ('feature_method', str, 'linear', "音频预处理方法", {'choices': ['linear', 'mfcc', 'fbank']}),
    ('decoder', str, 'ctc_beam_search', "结果解码方法", {'choices': ['ctc_beam_search', 'ctc_greedy']}),
    ('pinyin_mode', bool, False, '使用拼音识别模式', {}),
):
    add_arg(_opt_name, _opt_type, _opt_default, _opt_help, **_opt_extra)
# Drop the loop variables so they do not linger as module globals.
del _opt_name, _opt_type, _opt_default, _opt_help, _opt_extra
args = parser.parse_args()
# Web application: HTML templates come from ./templates and static assets from
# ./static, with the static files exposed at the URL root.
app = Flask(
    __name__,
    static_url_path="/",
    static_folder="static",
    template_folder="templates",
)
# Allow cross-origin (CORS) access to the app.
CORS(app)
# One recognizer built at import time from the command-line options; the
# request handlers in this module call it for every upload.
predictor = Predictor(
    # Acoustic model and input features.
    use_model=args.use_model,
    model_path=args.model_path.format(args.use_model, args.feature_method),
    feature_method=args.feature_method,
    vocab_path=args.vocab_path,
    pinyin_mode=args.pinyin_mode,
    use_gpu=args.use_gpu,
    # Decoder and beam-search settings.
    decoder=args.decoder,
    lang_model_path=args.lang_model_path,
    alpha=args.alpha,
    beta=args.beta,
    beam_size=args.beam_size,
    cutoff_prob=args.cutoff_prob,
    cutoff_top_n=args.cutoff_top_n,
    # Optional punctuation restoration.
    use_pun_model=args.use_pun,
    pun_model_dir=args.pun_model_dir,
)
# Speech recognition endpoint (short audio, recognized in a single pass).
@app.route("/recognition", methods=['POST'])
def recognition():
    """Transcribe one uploaded audio file.

    Expects a multipart upload with the file in the ``audio`` form field. The
    file is stored under ``args.save_path`` and passed to ``predictor.predict``.

    Returns:
        A JSON string:
        - success: ``{"code": 0, "msg": "success", "result": <text>, "score": <score, 3 decimals>}``
        - recognition raised: ``{"error": 1, "msg": "audio read fail!"}``
        - no file uploaded: ``{"error": 3, "msg": "audio is None!"}``
    """
    # .get() so a missing field reaches the "audio is None!" reply instead of a 400 abort.
    f = request.files.get('audio')
    if not f:
        return json.dumps({"error": 3, "msg": "audio is None!"}, ensure_ascii=False)
    # The client controls the filename: keep only its last path component so it
    # cannot escape save_path, and add a random prefix so concurrent uploads that
    # share a name do not overwrite each other before they are recognized.
    base_name = os.path.basename(f.filename.replace('\\', '/'))
    file_path = os.path.join(args.save_path, f"{uuid.uuid4().hex}_{base_name}")
    f.save(file_path)
    try:
        start = time.time()
        # Run recognition on the saved file.
        score, text = predictor.predict(audio_path=file_path, to_an=args.to_an)
        end = time.time()
        print("识别时间:%dms,识别结果:%s, 得分: %f" % (round((end - start) * 1000), text, score))
        # json.dumps (not str(dict)) so quotes in the transcript stay valid JSON;
        # float() turns a possible numpy scalar into a plain number.
        return json.dumps({"code": 0, "msg": "success", "result": text, "score": round(float(score), 3)},
                          ensure_ascii=False)
    except Exception as e:
        print(f'[{datetime.now()}] 短语音识别失败,错误信息:{e}', file=sys.stderr)
        return json.dumps({"error": 1, "msg": "audio read fail!"}, ensure_ascii=False)
# Long-audio speech recognition endpoint.
@app.route("/recognition_long_audio", methods=['POST'])
def recognition_long_audio():
    """Transcribe a long uploaded recording segment by segment.

    The upload (multipart field ``audio``) is saved under ``args.save_path``,
    split into segments by ``crop_audio_vad`` and each segment is passed to
    ``predictor.predict``. Segment transcripts are concatenated directly when
    ``args.use_pun`` is set, otherwise joined with ``','``; the reported score
    is the mean of the segment scores (0.0 if no segments were produced).

    Returns:
        A JSON string:
        - success: ``{"code": 0, "msg": "success", "result": <text>, "score": <mean score, 3 decimals>}``
        - recognition raised: ``{"error": 1, "msg": "audio read fail!"}``
        - no file uploaded: ``{"error": 3, "msg": "audio is None!"}``
    """
    # .get() so a missing field reaches the "audio is None!" reply instead of a 400 abort.
    f = request.files.get('audio')
    if not f:
        return json.dumps({"error": 3, "msg": "audio is None!"}, ensure_ascii=False)
    # The client controls the filename: keep only its last path component so it
    # cannot escape save_path, and add a random prefix so concurrent uploads that
    # share a name do not overwrite each other before they are recognized.
    base_name = os.path.basename(f.filename.replace('\\', '/'))
    file_path = os.path.join(args.save_path, f"{uuid.uuid4().hex}_{base_name}")
    f.save(file_path)
    try:
        start = time.time()
        # Split the long recording into segments, then recognize each one.
        segments = crop_audio_vad(file_path)
        texts = []
        scores = []
        for audio_bytes in segments:
            score, text = predictor.predict(audio_bytes=audio_bytes, to_an=args.to_an)
            texts.append(text)
            scores.append(score)
        # With use_pun the segments presumably already carry punctuation from the
        # punctuation model, so they are concatenated; otherwise ',' separates them.
        # join() avoids the leading separator the old "+= ',' + text" produced.
        separator = '' if args.use_pun else ','
        full_text = separator.join(texts)
        # No segments (e.g. no speech found) -> empty transcript, score 0.0,
        # instead of a ZeroDivisionError misreported as "audio read fail!".
        mean_score = float(sum(scores) / len(scores)) if scores else 0.0
        end = time.time()
        print("识别时间:%dms,识别结果:%s, 得分: %f" % (round((end - start) * 1000), full_text, mean_score))
        # json.dumps (not str(dict)) so quotes in the transcript stay valid JSON.
        return json.dumps({"code": 0, "msg": "success", "result": full_text, "score": round(mean_score, 3)},
                          ensure_ascii=False)
    except Exception as e:
        print(f'[{datetime.now()}] 长语音识别失败,错误信息:{e}', file=sys.stderr)
        return json.dumps({"error": 1, "msg": "audio read fail!"}, ensure_ascii=False)
@app.route('/')
def home():
    """Render the web front end: index.html from the app's template folder."""
    page_template = "index.html"
    return render_template(page_template)
if __name__ == '__main__':
    # Echo the effective configuration before starting.
    print_arguments(args)
    # Make sure the upload directory exists. exist_ok=True replaces the
    # exists()-then-makedirs() pair, which could race with another process
    # creating the same directory in between.
    # NOTE(review): this only runs when the file is executed directly; a WSGI
    # server importing the module would skip it.
    os.makedirs(args.save_path, exist_ok=True)
    app.run(host=args.host, port=args.port)