-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopt.py
176 lines (167 loc) · 5.69 KB
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import argparse
def parse_pretrain_opt(argv=None):
    """Define and parse the command-line options for pre-training.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing a list makes the
            function usable from tests or notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        Options that declare no ``default`` are ``None`` when omitted, and
        ``store_true`` flags default to ``False``.

    Raises:
        SystemExit: raised by argparse on ``--help`` or when an argument is
            invalid (for example a value outside its ``choices``).
    """
    parser = argparse.ArgumentParser()

    # general settings
    parser.add_argument(
        '--data_path',
        choices=['./dataset/dailydialog++'])
    # No `type` is given, so the parsed value is the *string* 'True' or
    # 'False', not a bool.
    parser.add_argument(
        '--split2four',
        help='Split dialogs whose turn length is longer than 4.',
        choices=['True', 'False'])
    parser.add_argument(
        '--train_mode',
        choices=['easy', 'medium', 'hard'],
        default='medium',
        help='curriculum learning mode when training')
    parser.add_argument(
        '--dataset',
        choices=['dailydialog_plusplus_mlr'],
        help='Dataset name (currently supports `dailydialog_plusplus_mlr`).')
    parser.add_argument(
        '--model',
        choices=['bert_metric', 'roberta_metric'],
        help='Model name ')
    parser.add_argument(
        '--trainer',
        choices=['mlr_pretrain'],
        help='Trainer name (currently supports `mlr_pretrain`).')
    parser.add_argument(
        '--mode',
        choices=['train', 'test'],
        help='Running mode (currently supports `train` / `test`).')
    # Parsed as a string (no `type`), so the GPU id is not converted to int.
    parser.add_argument(
        '--gpu',
        help='the id of GPU for loading model and data.')
    parser.add_argument(
        '--seed',
        type=int,
        help='The seed for reproducibility (optional).')
    parser.add_argument(
        '--max_seq_length',
        type=int,
        default=512,
        help='The max sequence length of the context-response pair.')

    # data settings
    parser.add_argument(
        '--train_batch_size',
        type=int,
        help='The batch size for training.')
    parser.add_argument(
        '--eval_batch_size',
        type=int,
        help='The batch size for validation.')
    parser.add_argument(
        '--test_batch_size',
        type=int,
        help='The batch size for testing.')

    # trainer settings
    parser.add_argument(
        '--num_epochs',
        type=int,
        help='The number of epochs for training.')
    parser.add_argument(
        '--display_steps',
        type=int,
        help='Print training loss every display_steps.')
    parser.add_argument(
        '--update_steps',
        type=int,
        help='Update centroids every update_steps.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        help='The initial learning rate for training.')
    parser.add_argument(
        '--warmup_proportion',
        type=float,
        help='The warmup proportion of the warmup strategy for LR scheduling.')
    parser.add_argument(
        '--inter_distance_lower_bound',
        type=float,
        help='The lower bound of the inter-cluster distance.')
    parser.add_argument(
        '--intra_distance_upper_bound',
        type=float,
        help='The upper bound of the intra-cluster distance.')
    # The trailing space before each "(for ...)" suffix matters: adjacent
    # string literals are concatenated with no separator.
    parser.add_argument(
        '--feature_distance_lower_bound',
        type=float,
        help='The lower bound of the inter-cluster feature distance. '
             '(for dual_mlr_loss computing)')
    parser.add_argument(
        '--feature_distance_upper_bound',
        type=float,
        help='The upper bound of the intra-cluster feature distance. '
             '(for dual_mlr_loss computing)')
    # NOTE(review): unlike every other bound, this option applies no
    # `type=float`, so its value is a *str* (or None). Kept as-is because
    # callers may rely on it; confirm before converting to float.
    parser.add_argument(
        '--score_distance_lower_bound',
        help='The lower bound of the inter-cluster score distance. '
             '(for dual_mlr_loss computing)')
    parser.add_argument(
        '--score_distance_upper_bound',
        type=float,
        help='The upper bound of the intra-cluster score distance. '
             '(for dual_mlr_loss computing)')
    parser.add_argument(
        '--feature_loss_weight',
        type=float,
        default=1.,
        help='The loss weight for feature_mlr_loss.')
    parser.add_argument(
        '--score_loss_weight',
        type=float,
        default=1.,
        help='The loss weight for score_mlr_loss.')
    parser.add_argument(
        '--bce_loss_weight',
        type=float,
        default=1.,
        help='The loss weight for bce_loss.')
    # nargs='+': both monitor options parse to lists of one or more values.
    parser.add_argument(
        '--monitor_metric_name',
        choices=['acc', 'mlr_loss', 'dual_mlr_loss', 'dual_fat_loss',
                 'margin_ranking_loss', 'fat_loss', 'vanilla_mlr_loss'],
        nargs='+',
        help='The metric used as the monitor for saving checkpoints.')
    parser.add_argument(
        '--monitor_metric_type',
        choices=['min', 'max'],
        nargs='+',
        help='The quantified type of the monitor metric. '
             '"min" means lower is better, while "max" means higher is better.')
    parser.add_argument(
        '--checkpoint_dir_path',
        help='The path of the output checkpoint directory.')
    parser.add_argument(
        '--logging_level',
        choices=['INFO', 'DEBUG'],
        help='The output logging level.')
    parser.add_argument(
        '--centroid_mode',
        choices=['mean'],
        help='The mode to compute the centroid feature.')
    parser.add_argument(
        '--distance_mode',
        choices=['cosine'],
        help='The mode to compute the distance between two features.')
    # No `choices` restriction: any model name string is accepted
    # (e.g. 'bert-base-uncased' or 'roberta-base').
    parser.add_argument(
        '--pretrained_model_name',
        help='The name of the pretrained model.')
    parser.add_argument(
        '--weighted_s_loss',
        action='store_true',
        help='Computing the inter-cluster separation loss with weights.')
    parser.add_argument(
        '--use_projection_head',
        action='store_true',
        help='Whether to use a projection head.')

    # argv=None -> argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args