-
Notifications
You must be signed in to change notification settings - Fork 5
/
Evaluate.lua
105 lines (93 loc) · 3.69 KB
/
Evaluate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
--[[/**************************************************************************
ReVal - A Simple and Effective Machine Translation Evaluation Metric Based on Recurrent Neural Networks.
Copyright (C) 2014 Rohit Gupta, University of Wolverhampton
This file is part of ReVal and is a modified version of the code distributed at https://github.com/stanfordnlp/treelstm.
ReVal is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
ReVal is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
**************************************************************************/
--]]
--[[
ReVal: MT Evaluation script
--]]
require('.')
local utf8=require 'lua-utf8'
--- Pearson correlation coefficient of two vectors.
-- Each input is centred by subtracting its own mean; the result is the dot
-- product of the centred vectors divided by the product of their norms.
-- The arguments themselves are not modified.
-- @param x first vector; must support `:mean()` and subtraction of a number,
--          yielding an object with `:dot()` and `:norm()`
-- @param y second vector, same requirements as `x`
-- @return the correlation coefficient as computed above
function pearson(x, y)
  local centered_x = x - x:mean()
  local centered_y = y - y:mean()
  local numerator = centered_x:dot(centered_y)
  local denominator = centered_x:norm() * centered_y:norm()
  return numerator / denominator
end
-- read command line arguments
-- NOTE(review): in the code visible here the parsed options are only used to
-- name the predictions file; the model class and checkpoint path are fixed.
local args = lapp [[
-m,--model (default dependency) Model architecture: [dependency, lstm, bilstm]
-l,--layers (default 1) Number of layers (ignored for Tree-LSTM)
-d,--dim (default 150) LSTM memory dimension
]]
local model_class = treelstm.TreeLSTMSim

-- directory containing dataset files
local data_dir = 'tmp/'

-- load embeddings
print('loading word embeddings')
local emb_dir = 'glove/'
local emb_prefix = emb_dir .. 'glove.840B'
local emb_vocab, emb_vecs = treelstm.read_embedding(emb_prefix .. '.vocab', emb_prefix .. '.300d.th')
local testvocab = treelstm.Vocab(data_dir .. 'testvocab-cased.txt')

-- Build a matrix holding one embedding row per test-vocabulary token, so that
-- only the vectors actually needed are passed on to the model.
local emb_dim = emb_vecs:size(2)
local num_unk = 0
local vecs = torch.Tensor(testvocab.size, emb_dim)
for i = 1, testvocab.size do
  local w = testvocab:token(i)
  if emb_vocab:contains(w) then
    vecs[i] = emb_vecs[emb_vocab:index(w)]
  else
    -- Unknown word: seed the generator with the concatenated code points of
    -- the word before drawing its vector. The seed depends only on the word
    -- (presumably so that a given unknown word always receives the same
    -- random vector -- confirm against the training script).
    local codes = {}
    for _, code in utf8.next, w do
      codes[#codes + 1] = code
    end
    -- kept local: this string previously leaked into the global table
    local mseed = table.concat(codes)
    torch.manualSeed(mseed)
    num_unk = num_unk + 1
    vecs[i]:uniform(-0.05, 0.05)
  end
end
print('unk count test= ' .. num_unk)

-- Point emb_vecs at the test-vocabulary embeddings. Assigning to the existing
-- local (instead of declaring a second one that shadows it) also drops this
-- script's reference to the full embedding tensor.
emb_vecs = vecs
-- Read the test set from the 'test/' subdirectory of the data directory,
-- using the test vocabulary.
local test_dataset = treelstm.read_test_dataset(data_dir .. 'test/', testvocab)

-- Restore the pre-trained model from its checkpoint and print its settings.
local checkpoint_path = 'trained_model/rel-dependency.1l.150d.th'
local best_dev_model = model_class.load(checkpoint_path)
print('Loaded model for testing have the following configuration')
best_dev_model:print_config()

-- Run the model over the whole test set with the prepared embeddings.
header('Evaluating on test set')
local test_predictions = best_dev_model:predict_dataset(test_dataset, emb_vecs)
-- write segment level scores
-- Create the predictions directory on first use.
if lfs.attributes(treelstm.predictions_dir) == nil then
  lfs.mkdir(treelstm.predictions_dir)
end

-- The directory is passed as a %s argument rather than spliced into the
-- format string, so a '%' in the path cannot be misread as a format directive.
local predictions_save_path = string.format(
  '%s/rel-%s.%dl.%dd.pred',
  treelstm.predictions_dir, args.model, args.layers, args.dim)
local predictions_file = torch.DiskFile(predictions_save_path, 'w')
print('writing segment level scores to ' .. predictions_save_path)

-- Each prediction p is rescaled to (p - 1) / 4 before being written.
-- NOTE(review): presumably the model predicts on a 1..5 scale, which this
-- maps to 0..1 -- confirm against the model's output range.
local num_segments = test_predictions:size(1)
local sysscore = 0
for i = 1, num_segments do
  local segscore = (test_predictions[i] - 1) / 4
  predictions_file:writeFloat(segscore)
  sysscore = sysscore + segscore
end
predictions_file:close()

-- The system-level score is the mean of the rescaled segment scores.
sysscore = sysscore / num_segments
print('ReVal Score:'.. sysscore)