utils.py
import math
from typing import List
import numpy as np


def input_transpose(sents, pad_token):
    """
    Transform a list of sentences of shape (batch_size, token_num) into a list of
    shape (max_token_num, batch_size), padding shorter sentences with `pad_token`.
    This time-major layout can be convenient when feeding batches to PyTorch.
    """
    max_len = max(len(s) for s in sents)
    batch_size = len(sents)

    sents_t = []
    for i in range(max_len):
        # at time step i, take the i-th token of every sentence, or the pad token if it is too short
        sents_t.append([sents[k][i] if len(sents[k]) > i else pad_token for k in range(batch_size)])

    return sents_t
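

# Minimal usage sketch (not part of the original file): demonstrates the shape change
# performed by input_transpose on a tiny hand-made batch. The tokens and the '<pad>'
# symbol below are assumptions chosen for illustration.
def _demo_input_transpose():
    sents = [['I', 'like', 'tea'], ['Hello']]
    transposed = input_transpose(sents, '<pad>')
    # transposed == [['I', 'Hello'], ['like', '<pad>'], ['tea', '<pad>']]
    return transposed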


def read_corpus(file_path, source):
    """
    Read a whitespace-tokenized corpus file (one sentence per line) and return a list
    of token lists. Target-side sentences ('tgt') are wrapped in <s> ... </s>.
    """
    data = []
    with open(file_path) as f:
        for line in f:
            sent = line.strip().split(' ')
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                sent = ['<s>'] + sent + ['</s>']
            data.append(sent)
    return data
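

# Minimal usage sketch (not part of the original file): reads a hypothetical
# token-per-line corpus. The file names 'train.de' / 'train.en' are assumptions
# made for illustration; any whitespace-tokenized text files would work.
def _demo_read_corpus():
    src_sents = read_corpus('train.de', source='src')  # hypothetical source file
    tgt_sents = read_corpus('train.en', source='tgt')  # target sentences get <s> ... </s>
    return src_sents, tgt_sents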


def batch_iter(data, batch_size, shuffle=False):
    """
    Given a list of (src_sent, tgt_sent) examples, optionally shuffle them and yield
    mini-batches, each sorted by source-sentence length in descending order.
    """
    batch_num = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size: (i + 1) * batch_size]
        examples = [data[idx] for idx in indices]

        # sort by source sentence length so padded batches can be packed efficiently
        examples = sorted(examples, key=lambda e: len(e[0]), reverse=True)
        src_sents = [e[0] for e in examples]
        tgt_sents = [e[1] for e in examples]

        yield src_sents, tgt_sents
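

# Minimal end-to-end sketch (not part of the original file): pairs source and target
# sentences and iterates over shuffled mini-batches. The tiny in-memory corpus below
# is made up for illustration.
if __name__ == '__main__':
    src = [['ich', 'mag', 'tee'], ['hallo']]
    tgt = [['<s>', 'I', 'like', 'tea', '</s>'], ['<s>', 'Hello', '</s>']]
    data = list(zip(src, tgt))

    for src_sents, tgt_sents in batch_iter(data, batch_size=2, shuffle=True):
        # within each batch, sentences are sorted by source length (longest first)
        print(src_sents)
        print(tgt_sents)
        # input_transpose pads the batch into time-major form for model input
        print(input_transpose(src_sents, '<pad>'))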