util.py
# -*- coding: utf-8 -*-
from torchtext.legacy.data import Iterator, BucketIterator
from torchtext.legacy import data
# For torchtext < 0.9, use the non-legacy imports instead:
# from torchtext.data import Iterator, BucketIterator
# from torchtext import data
import torch


def load_iters(batch_size=32, device="cpu", data_path='data', vectors=None, limit=100000):
    # Tokenized text field (lowercased, with lengths) and label field.
    TEXT = data.Field(lower=True, batch_first=True, include_lengths=True)
    LABEL = data.LabelField(batch_first=True)
    fields = {'text': ('text', TEXT),
              'label': ('label', LABEL)}

    train_data, test_data = data.TabularDataset.splits(
        path=data_path,
        train='train.jsonl',
        test='test.jsonl',
        format='json',
        fields=fields,
        filter_pred=lambda ex: ex.label != '-'  # drop examples whose label is '-' (i.e. unlabeled)
    )

    # Keep at most `limit` training examples; use the test set as the dev set.
    train_data = data.Dataset(train_data.examples[:limit], train_data.fields)
    dev_data = test_data
    print(f'using {len(train_data)} train data...')

    # Build vocabularies; use pre-trained vectors if provided.
    if vectors is not None:
        TEXT.build_vocab(train_data, vectors=vectors, unk_init=torch.Tensor.normal_)
    else:
        TEXT.build_vocab(train_data, max_size=50000)
    LABEL.build_vocab(train_data.label)

    # Bucket train/dev batches by text length to reduce padding.
    train_iter, dev_iter = BucketIterator.splits(
        (train_data, dev_data),
        batch_sizes=(batch_size, batch_size),
        device=device,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        repeat=False,
        shuffle=True
    )
    # Plain iterator for the test set, preserving the original example order.
    test_iter = Iterator(test_data, batch_size=batch_size, device=device, sort=False,
                         sort_within_batch=False, repeat=False, shuffle=False)
    return train_iter, dev_iter, test_iter, TEXT, LABEL
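
For reference, a minimal usage sketch of load_iters. It assumes a data/ directory containing train.jsonl and test.jsonl with "text" and "label" keys, as the function expects; the GloVe vectors are optional and shown only as an example.

# Usage sketch (not part of util.py); the data layout and GloVe choice are assumptions.
import torch
from torchtext.vocab import GloVe

from util import load_iters

device = "cuda" if torch.cuda.is_available() else "cpu"
train_iter, dev_iter, test_iter, TEXT, LABEL = load_iters(
    batch_size=32,
    device=device,
    data_path="data",
    vectors=GloVe(name="6B", dim=100),  # or None to build a plain 50k-word vocabulary
)

for batch in train_iter:
    (tokens, lengths), labels = batch.text, batch.label
    # tokens:  (batch, seq_len) LongTensor of token ids
    # lengths: (batch,) LongTensor of true sequence lengths
    # labels:  (batch,) LongTensor of class ids
    break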