-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_generator.py
executable file
·64 lines (60 loc) · 3.2 KB
/
train_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
from generator.train import main
# Command-line interface for the generator training script.
# Every option defined here ends up as an attribute on the namespace that is
# handed to generator.train.main by the entry point at the bottom of the file.
parser = argparse.ArgumentParser()

# Path arguments
parser.add_argument('--train', metavar='FILE', required=True,
                    help='path to training file')
parser.add_argument('--valid', metavar='FILE', required=True,
                    help='path to validation file')
parser.add_argument('--save-dir', default='checkpoints', metavar='DIR',
                    help='directory to save checkpoints and outputs')
parser.add_argument('--load-model', default='', metavar='FILE',
                    help='path to load checkpoint if specified')

# Architecture arguments
parser.add_argument('--vocab-size', type=int, default=10000, metavar='N',
                    help='keep N most frequent words in vocabulary')
parser.add_argument('--dim_z', type=int, default=128, metavar='D',
                    help='dimension of latent variable z')
parser.add_argument('--dim_emb', type=int, default=512, metavar='D',
                    help='dimension of word embedding')
parser.add_argument('--dim_h', type=int, default=1024, metavar='D',
                    help='dimension of hidden state per layer')
parser.add_argument('--nlayers', type=int, default=1, metavar='N',
                    help='number of layers')
parser.add_argument('--dim_d', type=int, default=512, metavar='D',
                    help='dimension of hidden state in AAE discriminator')

# Model arguments
parser.add_argument('--model_type', default='dae', metavar='M',
                    choices=['dae', 'vae', 'aae'],
                    help='which model to learn')
parser.add_argument('--lambda_kl', type=float, default=0, metavar='R',
                    help='weight for kl term in VAE')
parser.add_argument('--lambda_adv', type=float, default=0, metavar='R',
                    help='weight for adversarial loss in AAE')
parser.add_argument('--lambda_p', type=float, default=0, metavar='R',
                    help='weight for L1 penalty on posterior log-variance')
# --noise stays a raw comma-separated string at parse time; the entry point
# converts it to a list of floats before training starts.
# The two help literals are concatenated by the compiler, so the first one
# must end with ", " or the words run together in the --help output.
parser.add_argument('--noise', default='0,0,0,0', metavar='P,P,P,K',
                    help='word drop prob, blank prob, substitute prob, '
                         'max word shuffle distance')

# Training arguments
parser.add_argument('--dropout', type=float, default=0.5, metavar='DROP',
                    help='dropout probability (0 = no dropout)')
parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                    help='learning rate')
# Currently disabled option, kept for reference:
# parser.add_argument('--clip', type=float, default=0.25, metavar='NORM',
#                     help='gradient clipping')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of training epochs')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='batch size')

# Others
parser.add_argument('--seed', type=int, default=1111, metavar='N',
                    help='random seed')
parser.add_argument('--no-cuda', action='store_true',
                    help='disable CUDA')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='report interval')
if __name__ == '__main__':
    # Script entry point: read the command line, turn the comma-separated
    # --noise string into a list of floats, then start training.
    args = parser.parse_args()
    noise_fields = args.noise.split(',')
    args.noise = list(map(float, noise_fields))
    main(args)