"""
Configuration for Simultaneous Neural Machine Translation
"""
from collections import OrderedDict
# data_home = '/home/thoma/scratch/un16/'
# model_home = '/home/thoma/scratch/simul/'
# data_home = '/mnt/scratch/un16/'
# model_home = '/mnt/scratch/simul/'
data_home = '/misc/kcgscratch1/ChoGroup/thoma_data/simul_trans/un16/'
model_home = '/misc/kcgscratch1/ChoGroup/thoma_data/simul_trans/'
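
# --- Hedged sketch (not in the original repo) -------------------------------
# The absolute paths above are specific to the authors' cluster. One optional,
# behavior-preserving way to make them portable is an environment-variable
# override; the variable names SIMUL_DATA_HOME / SIMUL_MODEL_HOME below are
# illustrative assumptions, not part of this codebase.
import os

data_home = os.environ.get('SIMUL_DATA_HOME', data_home)
model_home = os.environ.get('SIMUL_MODEL_HOME', model_home)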
def pretrain_config():
    """Configuration for pretraining the underlying NMT model."""
    config = dict()

    # training set (source, target)
    config['datasets'] = [data_home + 'train.un16.en-zh.zh.c0.tok.clean.bpe20k.np',
                          data_home + 'train.un16.en-zh.en.c0.tok.clean.bpe20k.np']

    # validation set (source, target)
    config['valid_datasets'] = [data_home + 'devset.un16.en-zh.zh.c0.tok.bpe20k.np',
                                data_home + 'devset.un16.en-zh.en.c0.tok.bpe20k.np']

    # vocabulary (source, target)
    config['dictionaries'] = [data_home + 'train.un16.en-zh.zh.c0.tok.clean.bpe20k.vocab.pkl',
                              data_home + 'train.un16.en-zh.en.c0.tok.clean.bpe20k.vocab.pkl']

    # save the model to
    config['saveto'] = data_home + 'pretraining/model_un16_bpe2k_uni_zh-en.npz'
    config['reload_'] = True

    # model details
    config['dim_word'] = 512
    config['dim'] = 1028
    config['n_words'] = 20000
    config['n_words_src'] = 20000

    # learning details
    config['decay_c'] = 0
    config['clip_c'] = 1.
    config['use_dropout'] = False
    config['lrate'] = 0.0001
    config['optimizer'] = 'adadelta'
    config['patience'] = 1000
    config['maxlen'] = 50
    config['batch_size'] = 32
    config['valid_batch_size'] = 64
    config['validFreq'] = 1000
    config['dispFreq'] = 50
    config['saveFreq'] = 1000
    config['sampleFreq'] = 99

    return config
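
# --- Hedged sketch (not in the original repo) -------------------------------
# pretrain_config() only returns paths and hyper-parameters; nothing verifies
# that the data actually exists. A small illustrative helper like the one
# below could fail fast before launching a long pretraining run. The name
# check_pretrain_files is hypothetical.
def check_pretrain_files(config):
    """Raise FileNotFoundError if any configured data file is missing."""
    import os
    paths = config['datasets'] + config['valid_datasets'] + config['dictionaries']
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError('missing data files: {}'.format(missing))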
def rl_config():
    """Configuration for training the agent with the REINFORCE algorithm."""
    config = OrderedDict()  # general configuration

    # workspace
    config['workspace'] = model_home

    # training set (source, target); or leave it as None and the agent will
    # use the corpus saved with the pretrained model
    config['datasets'] = [data_home + 'train.un16.en-zh.en.c0.tok.clean.bpe20k.np',
                          data_home + 'train.un16.en-zh.zh.c0.tok.clean.bpe20k.np']

    # validation set (source, target); or leave it as None and the agent will
    # use the corpus saved with the pretrained model
    config['valid_datasets'] = [data_home + 'devset.un16.en-zh.en.c0.tok.bpe20k.np',
                                data_home + 'devset.un16.en-zh.zh.c0.tok.bpe20k.np']

    # vocabulary (source, target); or leave it as None and the agent will use
    # the dictionaries saved with the pretrained model
    config['dictionaries'] = [data_home + 'train.un16.en-zh.en.c0.tok.clean.bpe20k.vocab.pkl',
                              data_home + 'train.un16.en-zh.zh.c0.tok.clean.bpe20k.vocab.pkl']

    # pretrained model
    config['model'] = model_home + '.pretrained/model_un16_bpe2k_uni_en-zh.npz'
    config['option'] = model_home + '.pretrained/model_un16_bpe2k_uni_en-zh.npz.pkl'

    # critical training parameters
    config['sample'] = 10
    config['batchsize'] = 10
    config['rl_maxlen'] = 100
    config['target_ap'] = 0.8   # 0.75  # target delay when using average proportion (AP) as the reward
    config['target_cw'] = 8     # if cw > 0, use consecutive-wait (CW) mode

    # under construction
    config['forget'] = False

    # learning rates
    config['lr_policy'] = 0.0002
    config['lr_model'] = 0.00002

    # policy parameters
    config['prop'] = 0.5         # leave at default
    config['recurrent'] = True   # use a recurrent agent
    config['layernorm'] = False  # layer normalization for the GRU agent
    config['updater'] = 'REINFORCE'  # 'TRPO' does not work well
    config['act_mask'] = True    # leave at default

    # legacy parameters (possibly unused; leave at defaults)
    config['step'] = 1
    config['peek'] = 1
    config['s0'] = 1
    config['gamma'] = 1
    config['Rtype'] = 10
    config['maxsrc'] = 10
    config['pre'] = False
    config['coverage'] = False
    config['upper'] = False
    config['finetune'] = True
    config['train_gt'] = False   # when training with the ground truth, fix the random agent?
    config['full_att'] = True
    config['predict'] = True

    return config
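
# --- Hedged sketch (not in the original repo) -------------------------------
# rl_config() holds only strings, numbers, and booleans, so it serializes
# cleanly to JSON. Recording the exact hyper-parameters next to a run's
# checkpoints helps reproducibility. save_config is a hypothetical helper,
# not an API of this repository.
def save_config(config, path):
    """Write a config dict to `path` as pretty-printed JSON."""
    import json
    with open(path, 'w') as f:
        json.dump(config, f, indent=2)

# example usage (hypothetical):
#   save_config(rl_config(), model_home + 'rl_config.json')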