-
Notifications
You must be signed in to change notification settings - Fork 7
/
test_phrase_grammar.py
executable file
·186 lines (157 loc) · 5.49 KB
/
test_phrase_grammar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import argparse
import sys
import nltk
import numpy
import torch
from torch.autograd import Variable
import data_ptb
import data
#import data_nli as data
# Test model
def build_tree(depth, sen):
    """Recursively split *sen* at the position of maximum *depth*.

    Produces a nested-list binary(ish) parse of the tokens in *sen*.
    A single-token sentence collapses to the bare token rather than a
    one-element list.
    """
    assert len(depth) == len(sen)
    if len(depth) == 1:
        return sen[0]
    split = numpy.argmax(depth)
    left = sen[:split]
    right = sen[split + 1:]
    # The split token anchors the right-hand side; anything after it
    # becomes that token's sibling subtree.
    node = sen[split]
    if len(right) > 0:
        node = [node, build_tree(depth[split + 1:], right)]
    if len(left) == 0:
        return node
    return [build_tree(depth[:split], left), node]
def get_brackets(tree, idx=0):
    """Collect (start, end) leaf-index spans of constituents in *tree*.

    *tree* is either a leaf token or an iterable of subtrees (a nested
    list from ``build_tree`` or an ``nltk.Tree``). Leaves occupy one
    index each starting at *idx*; spans covering fewer than two leaves
    are not recorded. Returns ``(brackets, next_idx)`` where
    ``next_idx`` is the index just past the last leaf consumed.
    """
    spans = set()
    # Leaf token: consumes one position, contributes no bracket.
    # (Short-circuit keeps nltk.Tree unevaluated for plain lists.)
    if not (isinstance(tree, list) or isinstance(tree, nltk.Tree)):
        return spans, idx + 1
    cursor = idx
    for child in tree:
        child_spans, child_end = get_brackets(child, cursor)
        if child_end - cursor > 1:
            spans.add((cursor, child_end))
        spans.update(child_spans)
        cursor = child_end
    return spans, cursor
def mean(x):
    """Arithmetic mean of the values in *x* (assumes *x* is non-empty)."""
    total = sum(x)
    return total / len(x)
def test(model, corpus, cuda, mode='test', dictionary=None, prt=False):
    """Evaluate the model's unsupervised parsing against gold trees.

    Runs the model over each (short) sentence of the chosen corpus
    split, induces a binary parse from the model's gate values, and
    scores its bracketing against the gold tree.

    Args:
        model: trained model exposing ``attentions`` and ``gates``
            attributes after a forward pass, plus ``init_hidden``.
        corpus: corpus object holding ``<split>_sens`` / ``<split>_trees``.
        cuda: move inputs to the GPU when True.
        mode: split to evaluate — 'train', 'valid', 'test',
            'test_snli' or 'test_mnli'.
        dictionary: word -> id mapping; defaults to ``corpus.dictionary``.
        prt: print periodic and final statistics when True.

    Returns:
        Mean F1 over the evaluated sentences.

    Raises:
        ValueError: if *mode* names an unknown split.
    """
    model.eval()
    prec_list = []
    reca_list = []
    f1_list = []
    nsens = 0
    if dictionary is None:
        dictionary = corpus.dictionary
    if mode == 'train':
        sentences = corpus.train_sens
        trees = corpus.train_trees
    elif mode == 'valid':
        sentences = corpus.valid_sens
        trees = corpus.valid_trees
    elif mode == 'test':
        sentences = corpus.test_sens
        trees = corpus.test_trees
    elif mode == 'test_snli':
        sentences = corpus.test_snli_sens
        trees = corpus.test_snli_trees
    elif mode == 'test_mnli':
        sentences = corpus.test_mnli_sens
        trees = corpus.test_mnli_trees
    else:
        # An unknown mode used to fall through and crash later with a
        # NameError on `sentences`; fail fast with a clear message.
        raise ValueError('unknown mode: %r' % (mode,))
    for sen, sen_tree in zip(sentences, trees):
        # Only short sentences are evaluated, matching the original protocol.
        if len(sen) > 12:
            continue
        x = numpy.array([dictionary[w] for w in sen])
        input = Variable(torch.LongTensor(x[:, None]))
        if cuda:
            input = input.cuda()
        hidden = model.init_hidden(1)
        _, hidden = model(input, hidden)
        attentions = model.attentions.squeeze().data.cpu().numpy()
        gates = model.gates.squeeze().data.cpu().numpy()
        # Strip the sentence-boundary tokens before building the parse.
        depth = gates[1:-1]
        sen = sen[1:-1]
        attentions = attentions[1:-1]
        parse_tree = build_tree(depth, sen)
        model_out, _ = get_brackets(parse_tree)
        std_out, _ = get_brackets(sen_tree)
        overlap = model_out.intersection(std_out)
        # Epsilon guards against division by zero on empty bracket sets.
        prec = float(len(overlap)) / (len(model_out) + 1e-8)
        reca = float(len(overlap)) / (len(std_out) + 1e-8)
        # By convention an empty gold/model bracket set scores perfectly.
        if len(std_out) == 0:
            reca = 1.
        if len(model_out) == 0:
            prec = 1.
        f1 = 2 * prec * reca / (prec + reca + 1e-8)
        prec_list.append(prec)
        reca_list.append(reca)
        f1_list.append(f1)
        nsens += 1
        if prt and nsens % 100 == 0:
            print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))
            print('-' * 80)
            sys.stdout.flush()
    if prt:
        print('-' * 80)
        print('Mean Prec: %f, Mean Reca: %f, Mean F1: %f'
              % (mean(prec_list), mean(reca_list), mean(f1_list)))
        print('Number of sentence: %i' % nsens)
        sys.stdout.flush()
    return mean(f1_list)
if __name__ == '__main__':
    numpy.set_printoptions(precision=2, suppress=True, linewidth=5000)

    parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')
    # Model parameters.
    parser.add_argument('--data', type=str, default='./data/penn',
                        help='location of the train data corpus')
    parser.add_argument('--checkpoint', type=str, default='model/model_lm.pt',
                        help='model checkpoint to use')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
    parser.add_argument('--mode', type=str, default='train',
                        help='train/test/valid mode')
    args = parser.parse_args()

    print("data: ", args.data)
    sys.stdout.flush()

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    # Load the trained model checkpoint.
    with open(args.checkpoint, 'rb') as f:
        model = torch.load(f)
    if args.cuda:
        model.cuda()
        torch.cuda.manual_seed(args.seed)
    else:
        model.cpu()

    # We evaluate against PTB gold parses, so the evaluation corpus is
    # always PTB regardless of which corpus the model was trained on.
    corpus = data_ptb.Corpus('./data/ptb')

    # Use the dictionary the model was trained with; it only coincides
    # with the PTB dictionary when training data was PTB itself.
    if 'ptb' in args.data:
        c2_dict = corpus.dictionary
    else:
        corpus2 = data.Corpus(args.data)
        c2_dict = corpus2.dictionary
    sys.stdout.flush()

    test(model, corpus, args.cuda, mode=args.mode, dictionary=c2_dict, prt=True)