-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_prep.lua
63 lines (49 loc) · 2.3 KB
/
data_prep.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
require 'paths'
local maker = require 'data_maker'
-- Command-line interface for the preprocessing step. Options name the
-- directory holding the tokenized parallel corpus (-src_path), the
-- directory to write dictionaries and datasets into (-dst_path), the
-- per-split source/target file names, the dictionary frequency cutoff
-- (how -min_freq is applied is decided by data_maker.dict, not shown here)
-- and the RNG seed.
-- `cmd` is local: a bare assignment here would leak a global into _G.
local cmd = torch.CmdLine()
cmd:text()
cmd:text('preprocess the corpus')
cmd:text()
cmd:text('Options')
cmd:option('-src_path', 'data/nmt/prep', 'path to pre-processed data')
cmd:option('-dst_path', 'data/nmt/data', 'path to where dictionaries and datasets should be written')
cmd:option('-src_train', 'train.de-en.de', 'the name of source training data')
cmd:option('-src_valid', 'valid.de-en.de', 'the name of source valid data')
cmd:option('-src_test', 'test.de-en.de', 'the name of source testing data')
cmd:option('-tgt_train', 'train.de-en.en', 'the name of target training data')
cmd:option('-tgt_valid', 'valid.de-en.en', 'the name of target valid data')
cmd:option('-tgt_test', 'test.de-en.en', 'the name of target test data')
cmd:option('-min_freq', 3, 'remove words appearing less than min_freq')
cmd:option('-seed', 123, 'torch manual random number generator seed')
cmd:text()
-- `arg` is the script's command-line argument table supplied by the interpreter.
local opt = cmd:parse(arg)
torch.manualSeed(opt.seed)

-- Ensure the output directory exists (creating missing parents).
if not paths.dirp(opt.dst_path) then
  -- Single-quote the user-supplied path for the shell, escaping any embedded
  -- single quote as '\'' so that spaces are not word-split and shell
  -- metacharacters are not interpreted (avoids shell injection).
  local quoted = "'" .. (opt.dst_path:gsub("'", "'\\''")) .. "'"
  os.execute('mkdir -p ' .. quoted)
  -- os.execute's return convention differs between Lua 5.1/LuaJIT and 5.2+,
  -- so verify the outcome directly instead of inspecting its result.
  assert(paths.dirp(opt.dst_path),
    'could not create output directory: ' .. opt.dst_path)
end
-- Input corpus files (source and target side of each split) live under
-- -src_path; the two dictionaries are written under -dst_path.
local src_train = paths.concat(opt.src_path, opt.src_train)
local src_valid = paths.concat(opt.src_path, opt.src_valid)
local src_test = paths.concat(opt.src_path, opt.src_test)
local tgt_train = paths.concat(opt.src_path, opt.tgt_train)
local tgt_valid = paths.concat(opt.src_path, opt.tgt_valid)
local tgt_test = paths.concat(opt.src_path, opt.tgt_test)
local sdict_path = paths.concat(opt.dst_path, 'src.dict.t7')
local tdict_path = paths.concat(opt.dst_path, 'tgt.dict.t7')

-- Fail fast on a missing input file. Without this check, the valid/test
-- files are first opened only after dictionary building and training-set
-- conversion have already run, so a typo in a path would waste that work
-- (and how data_maker reports a missing file is not visible here).
for _, input in ipairs({ src_train, src_valid, src_test,
                         tgt_train, tgt_valid, tgt_test }) do
  if not paths.filep(input) then
    error('input file not found: ' .. input)
  end
end
-- Announce, build and persist one vocabulary from a training-side corpus.
-- Tokenization and the use of opt.min_freq are up to maker.dict
-- (data_maker, not shown here). Returns the built dictionary so the
-- dataset conversion below can use it.
local function build_and_save_dict(announcement, corpus_file, dict_file)
  print(announcement)
  local built = maker.dict(corpus_file, opt.min_freq)
  torch.save(dict_file, built)
  return built
end

-- Both dictionaries come from the training split only.
local sdict = build_and_save_dict('building source dictionary ...', src_train, sdict_path)
local tdict = build_and_save_dict('building target dictionary ...', tgt_train, tdict_path)
-- Convert every split of the parallel corpus into a dataset file under
-- -dst_path, using the dictionaries built from the training data.
-- Splits are processed strictly in order (train, valid, test). Each
-- converted dataset is referenced only inside its loop iteration, so a
-- saved split is no longer kept alive by this script while the next one
-- is converted.
local splits = {
  { message = 'converting training data ...', src = src_train, tgt = tgt_train, file = 'train.t7' },
  { message = 'converting valid data ...',    src = src_valid, tgt = tgt_valid, file = 'valid.t7' },
  { message = 'converting testing data ..',   src = src_test,  tgt = tgt_test,  file = 'test.t7' },
}

for _, split in ipairs(splits) do
  local out_path = paths.concat(opt.dst_path, split.file)
  print(split.message)
  local dataset = maker.convert(split.src, split.tgt, sdict, tdict)
  torch.save(out_path, dataset)
end