-
Notifications
You must be signed in to change notification settings - Fork 4
/
back_translate.py
66 lines (57 loc) · 2.87 KB
/
back_translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import importlib
import json
import logging
import random
import re
from typing import Any, Callable, Dict, List, Optional
def back_translate(text: str,
                   schemas: Dict[str, List[List[str]]],
                   keywords: Optional[List[str]] = None,
                   handle_func: Callable[[Dict[str, str]], Any] = lambda x: x) -> Any:
    """
    Augment one sentence by back-translating it through several platforms
    and translation schemas (chains of intermediate languages), producing
    multiple paraphrased sentences.

    Parameters:
    ------
    text: ``str``
        The input sentence.
    schemas: ``Dict[str, List[List[str]]]``
        Maps a translation platform name to the list of schemas (language
        chains) to run on it; see ``input/schemas.json``. Each platform is
        imported as ``<platform>.main`` and must expose
        ``back_translate(text, lang_list=...)``.
    keywords: ``List[str]``, optional, default=``None``
        If given, the keyword-mask method is applied as well: every distinct,
        non-empty keyword occurring in ``text`` is replaced by the mask token
        ``UNK`` before translation and restored afterwards. A masked result
        is kept only if the mask token survived the round trip.
    handle_func: ``Callable[[Dict[str, str]], Any]``, optional, default=``lambda x: x``
        Post-processes the result dict, e.g. to drop duplicated sentences,
        reshape the output, cap the number of results, or filter them with a
        matching model.

    Returns:
    ------
    Whatever ``handle_func`` returns for the result dict. The dict maps
    ``"origin"`` to ``text``, ``"<platform> <a->b->c>"`` to each plain
    back-translation, and ``"<platform> <a->b->c> kw_mask<keyword>"`` to each
    keyword-masked back-translation.

    Raises:
    ------
    ImportError / AttributeError
        If ``<platform>.main`` cannot be imported or has no
        ``back_translate`` attribute. Failures of individual translation
        calls are logged and skipped instead of being raised.
    """
    logger = logging.getLogger(__name__)
    # Translators may change the case of the mask token, so accept both forms.
    # A single-pass substitution keeps a keyword that itself contains "unk"
    # (e.g. "punk") from being rewritten a second time.
    # NOTE: a literal "UNK"/"unk" already present in ``text`` is
    # indistinguishable from the mask and is substituted as well.
    mask_pattern = re.compile("UNK|unk")

    res = {"origin": text}

    # The keyword hits depend neither on the platform nor on the schema, so
    # compute them once. ``dict.fromkeys`` removes duplicates while keeping a
    # deterministic order. Empty keywords are dropped: "" is contained in
    # every string and ``text.replace("", "UNK")`` would put the mask between
    # all characters.
    hit_keywords = [
        keyword
        for keyword in dict.fromkeys(keywords or ())
        if keyword and keyword in text
    ]

    for platform, schema_list in schemas.items():
        trans_func = importlib.import_module(f"{platform}.main").back_translate
        for schema in schema_list:
            try:
                schema_key = "->".join(schema)
            except TypeError:
                # Without a label the results of this schema cannot be stored
                # under a meaningful key, so skip it.
                logger.warning("Skipping malformed schema %r for platform %r",
                               schema, platform)
                continue

            try:
                res[f"{platform} {schema_key}"] = trans_func(text, lang_list=schema)
            except Exception:
                # Best effort: one failing platform/schema must not abort the
                # remaining ones.
                logger.warning("Back translation failed for %s %s",
                               platform, schema_key, exc_info=True)

            # Keyword mask: translate with each hit keyword hidden behind UNK.
            for selected_keyword in hit_keywords:
                try:
                    replaced_text = text.replace(selected_keyword, "UNK")
                    back_translate_res = trans_func(replaced_text, lang_list=schema)
                    restored, restored_count = mask_pattern.subn(
                        lambda _match, kw=selected_keyword: kw,
                        back_translate_res)
                except Exception:
                    logger.warning("Keyword-mask back translation failed for %s %s (%r)",
                                   platform, schema_key, selected_keyword,
                                   exc_info=True)
                    continue
                # Keep the sentence only if the mask survived; otherwise the
                # keyword was lost in translation.
                if restored_count:
                    res[f"{platform} {schema_key} kw_mask{selected_keyword}"] = restored

    return handle_func(res)
def test():
    """
    Manual smoke test: back-translate a sample sentence using the schemas in
    ``./input/schemas.json`` and the keywords in ``./input/keywords.txt``,
    then print the raw result dict and the de-duplicated sentences as JSON.
    """
    def handle_res(res):
        """Print every distinct generated sentence and return them as a list."""
        no_repeat = list(set(res.values()))
        for item in no_repeat:
            print(item)
        return no_repeat

    # Context managers make sure both input files are closed again.
    # NOTE(review): the files are read with the platform default encoding, as
    # before - confirm whether utf-8 should be forced.
    with open("./input/schemas.json", "r") as schema_file:
        schemas = json.load(schema_file)
    with open("./input/keywords.txt", "r") as keyword_file:
        stripped_lines = [line.strip() for line in keyword_file]
    # A blank line would otherwise become an empty keyword, which is contained
    # in every sentence.
    keywords = [keyword for keyword in stripped_lines if keyword]

    result = back_translate("后评估工作如何开展?", schemas, keywords)
    print(json.dumps(result, indent=4, ensure_ascii=False))
    print(json.dumps(handle_res(result), indent=4, ensure_ascii=False))
# Script entry point: run the manual smoke test only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    test()