# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0
import re
import json
from copy import deepcopy


def process_text(text):
    """Normalize raw text: map escaped UTF-8 punctuation to ASCII and collapse whitespace."""
    # Replace literal escaped UTF-8 byte sequences for curly quotes, apostrophes,
    # non-breaking spaces, and dashes with their plain ASCII equivalents.
    text = text.replace('\\xe2\\x80\\x9c', '"')
    text = text.replace('\\xe2\\x80\\x9d', '"')
    text = text.replace('\\xe2\\x80\\x98', "'")
    text = text.replace('\\xe2\\x80\\x99', "'")
    text = text.replace('\u201c', '"')
    text = text.replace('\u201d', '"')
    text = text.replace('\\xc2\\xa0', " ")
    text = text.replace('\\"', '"')
    text = text.replace("\\'", "'")
    text = text.replace("\\xe2\\x80\\x94", "-")
    text = text.replace("\\xe2\\x80\\x93", "-")
    # Resolve any remaining escape sequences, drop non-ASCII characters,
    # round-trip through JSON (a no-op for a plain string), and strip whitespace.
    text = json.loads(json.dumps(text.encode('utf-8').decode('unicode_escape').encode("ascii", "ignore").decode())).strip()
    # Normalize whitespace: collapse runs of blank lines into a single paragraph
    # break and squeeze runs of spaces/tabs into one space.
    text = re.sub(r'[ \t\r]*\n\n+[ \t\r]*', '\n\n', text)
    text = re.sub(r'[\t \r]+', ' ', text)
    text = "\n\n".join(re.split(r'\n\n+', text))
    return text
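
# Illustrative example (input string assumed for demonstration, not taken from
# the original data):
#   process_text('She said \\xe2\\x80\\x9chi\\xe2\\x80\\x9d.\n\n\nNext  sentence.')
#   -> 'She said "hi".\n\nNext sentence.'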


def get_offsets(rsd_str, tokenizer, doc_id):
    """Segment rsd_str into sentences and tokens, recording character offsets."""
    # Pass 1: locate each sentence (split on blank lines) in the raw string.
    indexer = 0
    sent_offset = list()
    sents = rsd_str.split("\n\n")
    for i, s in enumerate(sents):
        while not rsd_str[indexer:].startswith(s) and indexer < len(rsd_str):
            indexer += 1
        if indexer < len(rsd_str):
            sent_start = indexer
            sent_end = sent_start + len(s) - 1
            assert rsd_str[sent_start:sent_end + 1] == s, "sentence offset not match %s-%d" % (doc_id, i)
            sent_offset.append((sent_start, sent_end))
            indexer = sent_end + 1
    assert len(sent_offset) == len(sents), "sentence segmentation offset error in: %s" % doc_id
    # Pass 2: locate each token within its sentence.
    token_offsets = list()
    for i, sent_text in enumerate(sents):
        indexer = 0
        t_offset = list()
        tok = [t.text for t in tokenizer.tokenize(sent_text)]
        for j, t in enumerate(tok):
            while not sent_text[indexer:].startswith(t) and indexer < len(sent_text):
                indexer += 1
            if indexer < len(sent_text):
                t_start = indexer
                t_end = t_start + len(t) - 1
                assert sent_text[t_start:t_end + 1] == t, "token offset not match %s-%d-%d" % (doc_id, i, j)
                t_offset.append((t_start, t_end))
                indexer = t_end + 1
        token_offsets.append(deepcopy(t_offset))
    # Pass 3: assemble per-sentence records with segment ids, offsets, and tokens.
    data = list()
    for i in range(0, len(sents)):
        sent_item = dict()
        sent_item["text"] = sents[i]
        sent_item["segment_id"] = doc_id + "_$$segment" + str(i)
        sent_item["sent_start"] = int(sent_offset[i][0])
        sent_item["sent_end"] = int(sent_offset[i][1]) + 1
        sent_item["tokens"] = list()
        tok = [t.text for t in tokenizer.tokenize(sents[i])]
        for j in range(0, len(token_offsets[i])):
            token_item = dict()
            token_id = 'token-%d-%d' % (i, j)
            token_item["start_char"] = token_offsets[i][j][0]
            token_item["token_id"] = token_id
            token_item["end_char"] = token_offsets[i][j][1] + 1
            token_item["text"] = tok[j]
            assert tok[j] == sents[i][token_offsets[i][j][0]: token_offsets[i][j][1] + 1]
            sent_item["tokens"].append(deepcopy(token_item))
        data.append(deepcopy(sent_item))
    return data
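

# --- Usage sketch (added for illustration; not part of the original module) ---
# get_offsets expects a tokenizer whose tokenize() returns objects exposing a
# .text attribute. The whitespace tokenizer below is a minimal stand-in to make
# the example self-contained; the real project may supply a different tokenizer.
if __name__ == "__main__":

    class _Token:
        def __init__(self, text):
            self.text = text

    class _WhitespaceTokenizer:
        def tokenize(self, text):
            # Naive whitespace split; get_offsets recovers character offsets itself.
            return [_Token(t) for t in text.split()]

    raw = 'She said \\xe2\\x80\\x9chello\\xe2\\x80\\x9d.\n\n\nSecond  sentence here.'
    cleaned = process_text(raw)
    for seg in get_offsets(cleaned, _WhitespaceTokenizer(), doc_id="doc-0"):
        print(seg["segment_id"], (seg["sent_start"], seg["sent_end"]),
              [t["text"] for t in seg["tokens"]])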