#! /usr/bin/env python
import os
import re
import argparse
import datetime
import torch
import torchtext.data as data
import torchtext.datasets as datasets
import model
import train
import mydatasets
from tqdm import tqdm
import pandas as pd
from gensim.models import word2vec  # used for loading the pretrained word2vec model
import data.textProcess as preprocess
import chardet
from urllib.request import urlopen
from bs4 import BeautifulSoup as bs
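# The imports above imply these third-party dependencies: torch, torchtext (the
# legacy data/datasets API), gensim, pandas, tqdm, chardet, and beautifulsoup4;
# model, train, mydatasets, and data.textProcess are local modules of this repo.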
parser = argparse.ArgumentParser(description='CNN text classifier')
# learning
parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
parser.add_argument('-epochs', type=int, default=100, help='number of epochs for train [default: 100]')
parser.add_argument('-batch-size', type=int, default=4, help='batch size for training [default: 4]')
parser.add_argument('-log-interval', type=int, default=1, help='how many steps to wait before logging training status [default: 1]')
parser.add_argument('-test-interval', type=int, default=100, help='how many steps to wait before testing [default: 100]')
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default: 500]')
parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
parser.add_argument('-early-stop', type=int, default=1000, help='number of iterations without improvement before stopping early')
# NOTE: argparse's type=bool treats any non-empty string as True, so a disable flag is exposed instead
parser.add_argument('-no-save-best', dest='save_best', action='store_false', help='do not save a snapshot on best performance')
# data
parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch')
# model
parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
parser.add_argument('-max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
parser.add_argument('-embed-dim', type=int, default=128, help='dimensionality of the word embeddings [default: 128]')
parser.add_argument('-kernel-num', type=int, default=100, help='number of kernels per kernel size')
parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes to use for convolution')
parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
parser.add_argument('-word2vec-model', type=str, default='./word2Vec/model/word2Vec.model', help='filename of word2vec model')
parser.add_argument('-use-word2vec', action='store_true', default=False, help='whether to use pretrained word2vec embeddings')
parser.add_argument('-model', type=str, default='CNN', help='model type: CNN or LSTM')
parser.add_argument('-hidden-size', type=int, default=128, help='hidden layer size of the LSTM')
parser.add_argument('-lstm-num', type=int, default=1, help='number of LSTM layers')
# device
parser.add_argument('-device', type=int, default=-1, help='device to use for iterating over data, -1 means CPU [default: -1]')
parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the GPU')
# option
parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
parser.add_argument('-predict-url', type=str, default=None, help='predict the resume url given')
parser.add_argument('-test', action='store_true', default=False, help='evaluate on the test set instead of training')
args = parser.parse_args()
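# Example invocations (a sketch; the snapshot path and texts below are
# illustrative, not files shipped with the repo):
#   python main.py                                   # train with the defaults above
#   python main.py -lr 0.0005 -shuffle -use-word2vec
#   python main.py -predict '计算机 专业 精通 Python'
#   python main.py -predict-url 'http://example.com/resume.html'
#   python main.py -test -snapshot './snapshot/2018-01-01_00-00-00/best_steps_100.pt'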
excel = './data/标注数据.xlsx'  # spreadsheet of labeled resumes ("标注数据" = "labeled data")
text_path = './data/text/processed/'
#new_model = word2vec.Word2Vec.load(args.word2vec_model)
print("\nLoading data...")
tokenize = lambda x: re.split(' ', x)
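# A minimal sketch of the tokenizer's behavior (assuming the processed corpus is
# already space-segmented): re.split(' ', 'Python 开发 工程师') returns
# ['Python', '开发', '工程师']; the stop_words below then drop whitespace-only tokens.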
text_field = data.Field(tokenize=tokenize, lower=True, stop_words=['\r', '\n', '\t', '\xa0', ' ', ''])
label_field = data.Field(sequential=False)
train_iter, dev_iter = mydatasets.resume(excel, text_path, text_field, label_field, args.batch_size,
                                         device=-1, repeat=False,
                                         use_wv=args.use_word2vec,
                                         wv_model=args.word2vec_model,
                                         )
# Inspect the dev (held-out) set
'''
for batch in dev_iter:
    print(batch.label, batch.text)
os._exit(1)
'''
# Inspect the vocabulary of the text field
'''
for word in text_field.vocab.itos[:100]:
    print(word)
print('vocab size:', len(text_field.vocab))
os._exit(1)
'''
# Inspect the vocabulary of the label field
'''
for label in label_field.vocab.itos:
    print(label)
print(len(label_field.vocab))
os._exit(1)
'''
args.embed_num = len(text_field.vocab) # 9499
args.class_num = len(label_field.vocab) - 1 # 16; the -1 excludes the <unk> token
print('embed num:', args.embed_num, '\nclass num:', args.class_num)
args.cuda = (not args.no_cuda) and torch.cuda.is_available()
del args.no_cuda
args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
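# e.g. the default '3,4,5' parses to [3, 4, 5]; presumably the CNN builds one
# convolution branch with args.kernel_num filters per kernel size, TextCNN-style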
args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
if args.model == 'CNN':
    print('Using CNN model')
    cnn = model.CNN_Text(args, text_field)
elif args.model == 'LSTM':
    print('Using LSTM model')
    cnn = model.LSTM_Text(args, text_field)
else:
    raise ValueError('unsupported model type: {}'.format(args.model))
if args.snapshot is not None:
    print('\nLoading model from {}...'.format(args.snapshot))
    cnn.load_state_dict(torch.load(args.snapshot))
if args.cuda:
    #torch.cuda.set_device(args.device)
    print('Using CUDA')
    cnn = cnn.cuda()
else:
    print('Using CPU')
#------ Prediction ------#
if args.predict or args.predict_url:
    # Fall back to a pretrained snapshot when none was given on the command line
    if args.snapshot is None:
        args.snapshot = './snapshot/TextCNN_Word2Vec/best_steps_4100.pt'
        print('\nLoading model from {}...'.format(args.snapshot))
        cnn.load_state_dict(torch.load(args.snapshot))
    # Predict from a sentence given on the command line
    if args.predict is not None:
        # Preprocess the text first, then predict the major category
        text = preprocess.textPreprocess(args.predict)
        label = train.predict(text, cnn, text_field, label_field, args.cuda)
        print('\n[resume text] {}\n[predicted major] {}\n'.format(text, label))
    # Given a URL, scrape its text first, then predict
    elif args.predict_url is not None:
        html = urlopen(args.predict_url).read()
        encoding = chardet.detect(html)['encoding']
        if encoding == 'GB2312':
            encoding = 'gbk'  # gbk is a superset of GB2312
        elif encoding == 'iso-8859-1':
            encoding = 'utf-8'  # chardet often mislabels UTF-8 pages as latin-1
        html = bs(html, 'html.parser', from_encoding=encoding)
        # Strip script and style tags before extracting the visible text
        for script in html.findAll('script'):
            script.extract()
        for style in html.findAll('style'):
            style.extract()
        text = ''.join([s for s in html.text.splitlines(True) if s.strip()])
        text = preprocess.textPreprocess(text)
        label = train.predict(text, cnn, text_field, label_field, args.cuda)
        print('\n[resume text] {}\n[predicted major] {}\n'.format(text, label))
#------ Testing ------#
elif args.test:
    try:
        # mydatasets.resume only builds train/dev iterators above, so test_iter
        # is undefined here and this branch falls through to the except clause
        train.eval(test_iter, cnn, args)
    except Exception:
        print("\nSorry. The test dataset doesn't exist.\n")
#------ Training ------#
else:
    print()
    try:
        train.train(train_iter, dev_iter, cnn, args)
    except KeyboardInterrupt:
        print('\n' + '-' * 89)
        print('Exiting from training early')