main.py
import argparse
import logging
import pickle
import re
import sys

import jieba
import torch
import torch.nn.functional as F
from torchtext import data
from torchtext.vocab import Vectors

from model import TextCNN


def word_cut(text):
    # Replace every character that is not a Chinese character, a-z, A-Z, or 0-9 with a space.
    text = re.compile(r'[^A-Za-z0-9\u4e00-\u9fa5]').sub(' ', text)
    return [word.strip() for word in jieba.cut(text) if word.strip()]
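
# Illustrative only; the exact segmentation depends on jieba's dictionary:
#   word_cut('这部电影很好看!Great movie!')  ->  ['这部', '电影', '很', '好看', 'Great', 'movie']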


def load_iter(text_field, label_field, args, **kwargs):
    print("load_iter...")
    # data/train.tsv and data/test.tsv have three columns: index, label, text.
    # The index column is discarded (its field is set to None).
    train_dataset, test_dataset = data.TabularDataset.splits(
        path='data', format='tsv', skip_header=True,
        train='train.tsv', test='test.tsv',
        fields=[
            ('index', None),
            ('label', label_field),
            ('text', text_field)
        ]
    )
    if args.static and args.pretrained_name and args.pretrained_path:
        # Initialize the text vocabulary from pre-trained word vectors.
        vectors = Vectors(name=args.pretrained_name, cache=args.pretrained_path)
        text_field.build_vocab(train_dataset, test_dataset, vectors=vectors)
    else:
        text_field.build_vocab(train_dataset, test_dataset)
    label_field.build_vocab(train_dataset, test_dataset)  # word -> index mapping
    # The whole test set is served as a single batch.
    train_iter, test_iter = data.Iterator.splits(
        (train_dataset, test_dataset),
        batch_sizes=(args.batch_size, len(test_dataset)),
        sort_key=lambda x: len(x.text),
        **kwargs
    )
    print("finish load_iter")
    return train_iter, test_iter
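
# Each batch yields batch.text of shape (seq_len, batch_size) and batch.label of
# shape (batch_size,); train() and evaluate() transpose text to batch-first for the CNN.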


def load(args):
    print('load_data...')
    text_field = data.Field(lower=True, tokenize=word_cut)
    # unk_token=None and pad_token=None keep '<unk>'/'<pad>' out of the label vocab,
    # so labels are indexed from 0 and len(label_field.vocab) equals the class count.
    label_field = data.Field(sequential=False, unk_token=None, pad_token=None)
    train_iter, test_iter = load_iter(text_field, label_field, args, device=-1, repeat=False, shuffle=True)
    args.vocabulary_size = len(text_field.vocab)
    if args.static:
        # With static vectors, the embedding size is dictated by the pre-trained file.
        args.embedding_dim = text_field.vocab.vectors.size()[-1]
        args.vectors = text_field.vocab.vectors
    if args.multichannel:
        # Multichannel uses one frozen (static) and one fine-tuned (non-static) embedding.
        args.static = True
        args.non_static = True
    args.class_num = len(label_field.vocab)
    # Note: a Field carries special tokens whose unk_token defaults to '<unk>'; unless it
    # is overridden (as above), the vocab contains '<unk>' and the class count would
    # have to be adjusted downwards.
    args.cuda = args.device != -1 and torch.cuda.is_available()
    args.filter_sizes = [int(size) for size in args.filter_sizes.split(',')]
    print('Finish load_data')
    return train_iter, test_iter, text_field.vocab


def set_args():
    parser = argparse.ArgumentParser(description='TextCNN text classifier')
    # model
    parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
    parser.add_argument('-epochs', type=int, default=256, help='number of epochs for train [default: 256]')
    parser.add_argument('-batch-size', type=int, default=128, help='batch size for training [default: 128]')
    parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
    parser.add_argument('-embedding-dim', type=int, default=128, help='number of embedding dimensions [default: 128]')
    parser.add_argument('-filter-num', type=int, default=100, help='number of filters per filter size')
    parser.add_argument('-filter-sizes', type=str, default='3,4,5', help='comma-separated filter sizes to use for convolution')
    # argparse's type=bool treats any non-empty string (including 'False') as True,
    # so the boolean options use store_true flags instead.
    parser.add_argument('-static', action='store_true', default=False, help='whether to use static pre-trained word vectors')
    parser.add_argument('-non-static', action='store_true', default=False, help='whether to fine-tune static pre-trained word vectors')
    parser.add_argument('-multichannel', action='store_true', default=False, help='whether to use 2 channels of word vectors')
    parser.add_argument('-pretrained-name', type=str, default='sgns.zhihu.word', help='filename of pre-trained word vectors')
    parser.add_argument('-pretrained-path', type=str, default='data', help='path of pre-trained word vectors')
    # device
    parser.add_argument('-device', type=int, default=-1, help='device to use for iterating data, -1 means cpu [default: -1]')
    return parser.parse_args()
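
# Example invocations (paths are the defaults assumed above):
#   python main.py                  # random embeddings, CPU
#   python main.py -static          # frozen pre-trained vectors from data/sgns.zhihu.word
#   python main.py -multichannel    # one static + one fine-tuned embedding channel
#   python main.py -device 0        # train on GPU 0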


def train(train_iter, model, args):
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    bestLoss = None
    model.train()
    for epoch in range(1, args.epochs + 1):
        epochLoss = 0
        steps = 0
        for batch in train_iter:
            feature, target = batch.text, batch.label
            # torchtext yields (seq_len, batch); transpose in place to (batch, seq_len).
            feature.data.t_()
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            optimizer.zero_grad()
            logits = model(feature)
            loss = F.cross_entropy(logits, target)
            epochLoss += loss.item()
            loss.backward()
            optimizer.step()
            steps += 1
            corrects = (torch.max(logits, 1)[1].view(target.size()).data == target.data).sum()
            train_acc = 100.0 * corrects / batch.batch_size
            sys.stdout.write('\repoch[{}/{}] batch[{}/{}] - loss: {:.6f} acc: {}/{}={:.4f}%'.
                             format(epoch, args.epochs, steps, len(train_iter), loss.item(), corrects, batch.batch_size, train_acc))
        # Save whenever the summed epoch loss improves (assumes the model/ directory exists).
        print("")
        if bestLoss is None or epochLoss < bestLoss:
            print('epoch[{}/{}] save model with loss={}'.format(epoch, args.epochs, epochLoss))
            bestLoss = epochLoss
            torch.save(model, 'model/textCNNModel.pt')
        # Linearly decay the learning rate towards 0 over the remaining epochs.
        for param_group in optimizer.param_groups:
            param_group['lr'] -= args.lr / args.epochs


def evaluate(test_iter, model, args):
    model.eval()
    corrects, avg_loss = 0, 0
    with torch.no_grad():
        for batch in test_iter:
            feature, target = batch.text, batch.label
            # Labels are already 0-based (the label Field sets unk_token=None),
            # so no offset is subtracted from target, just as in train().
            feature.data.t_()
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            logits = model(feature)
            loss = F.cross_entropy(logits, target)
            # Weight the batch-mean loss by the batch size so dividing by the
            # dataset size below yields a true per-example average.
            avg_loss += loss.item() * batch.batch_size
            corrects += (torch.max(logits, 1)[1].view(target.size()).data == target.data).sum()
    size = len(test_iter.dataset)
    avg_loss /= size
    accuracy = 100.0 * corrects / size
    print('\nEvaluation - loss: {:.6f} acc: {:.4f}%({}/{}) \n'.format(avg_loss, accuracy, corrects, size))
    return accuracy
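

# A minimal inference sketch, not part of the original script: it assumes the
# model saved by train() and the vocab pickled in __main__. The predict() helper
# and its default paths are hypothetical conveniences that mirror the names above.
def predict(sentence, model_path='model/textCNNModel.pt', vocab_path='data/vocab.pkl'):
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    model = torch.load(model_path)
    model.eval()
    # Tokenize exactly as in training (the text Field used lower=True and tokenize=word_cut).
    tokens = word_cut(sentence.lower())
    # vocab.stoi is a defaultdict, so unknown words map to the <unk> index.
    indices = [vocab.stoi[token] for token in tokens]
    # Shape (1, seq_len): batch-first, matching the transpose done in train().
    # Note: sentences shorter than the largest filter size may need padding first.
    feature = torch.tensor(indices, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        logits = model(feature)
    return torch.max(logits, 1)[1].item()  # predicted label index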


if __name__ == '__main__':
    jieba.setLogLevel(logging.INFO)  # suppress jieba's loading messages
    args = set_args()
    train_iter, test_iter, vocab = load(args)
    # Persist the vocabulary so the pickled model can be reused for inference.
    with open("data/vocab.pkl", "wb") as f:
        pickle.dump(vocab, f)
    text_cnn = TextCNN(args)
    if args.cuda:
        torch.cuda.set_device(args.device)
        text_cnn = text_cnn.cuda()
    train(train_iter, text_cnn, args)
    # Report accuracy of the final model on the held-out test set.
    evaluate(test_iter, text_cnn, args)