forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopts.py
55 lines (47 loc) · 2.93 KB
/
opts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import argparse
def fill_config(args):
    """Derive secondary settings from the parsed options.

    Adds three attributes to ``args`` in place and returns the same object:

    * ``device``   -- a ``torch.device`` built from ``args.gpu``. A negative
      integer selects the CPU; any other value is handed to
      ``torch.device`` unchanged.
    * ``dec_ninp`` -- ``3 * args.nhid`` when ``args.title`` is truthy,
      otherwise ``2 * args.nhid``.
    * ``fnames``   -- ``[args.train_file, args.valid_file, args.test_file]``.
    """
    # torch.device() does not accept a negative ordinal, so map it to the CPU
    # instead of failing; this gives callers a way to run without a GPU.
    # The isinstance guard keeps non-integer values (e.g. a device string)
    # flowing straight through to torch.device as before.
    if isinstance(args.gpu, int) and args.gpu < 0:
        args.device = torch.device('cpu')
    else:
        args.device = torch.device(args.gpu)

    args.dec_ninp = args.nhid * 3 if args.title else args.nhid * 2
    args.fnames = [args.train_file, args.valid_file, args.test_file]
    return args
def vocab_config(args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
    """Attach the five vocabulary objects to ``args`` and return it.

    Each vocabulary is stored on ``args`` under an attribute with the same
    name as the corresponding parameter. ``args`` is modified in place and
    the same object is returned.
    """
    vocabularies = {
        'ent_vocab': ent_vocab,
        'rel_vocab': rel_vocab,
        'text_vocab': text_vocab,
        'ent_text_vocab': ent_text_vocab,
        'title_vocab': title_vocab,
    }
    for attr_name, vocab in vocabularies.items():
        setattr(args, attr_name, vocab)
    return args
def get_args():
    """Parse the command-line options and return the completed namespace.

    Builds the argument parser, parses the process arguments, and passes the
    resulting namespace through ``fill_config`` before returning it.
    """
    parser = argparse.ArgumentParser(description='Graph Writer in DGL')

    # One entry per option, in the order they are registered with the parser.
    option_specs = (
        ('--nhid', dict(default=500, type=int, help='hidden size')),
        ('--nhead', dict(default=4, type=int, help='number of heads')),
        ('--head_dim', dict(default=125, type=int, help='head dim')),
        ('--weight_decay', dict(default=0.0, type=float, help='weight decay')),
        ('--prop', dict(default=6, type=int, help='number of layers of gnn')),
        ('--title', dict(action='store_true', help='use title input')),
        ('--test', dict(action='store_true', help='inference mode')),
        ('--batch_size', dict(default=32, type=int, help='batch_size')),
        ('--beam_size', dict(default=4, type=int, help='beam size, 1 for greedy')),
        ('--epoch', dict(default=20, type=int, help='training epoch')),
        ('--beam_max_len', dict(default=200, type=int, help='max length of the generated text')),
        ('--enc_lstm_layers', dict(default=2, type=int, help='number of layers of lstm')),
        ('--lr', dict(default=1e-1, type=float, help='learning rate')),
        # NOTE: an '--lr_decay' option (default 1e-8, float) is currently disabled.
        ('--clip', dict(default=1, type=float, help='gradient clip')),
        ('--emb_drop', dict(default=0.0, type=float, help='embedding dropout')),
        ('--attn_drop', dict(default=0.1, type=float, help='attention dropout')),
        ('--drop', dict(default=0.1, type=float, help='dropout')),
        ('--lp', dict(default=1.0, type=float, help='length penalty')),
        ('--graph_enc', dict(default='gtrans', type=str, help='gnn mode, we only support the graph transformer now')),
        ('--train_file', dict(default='data/unprocessed.train.json', type=str, help='training file')),
        ('--valid_file', dict(default='data/unprocessed.val.json', type=str, help='validation file')),
        ('--test_file', dict(default='data/unprocessed.test.json', type=str, help='test file')),
        ('--save_dataset', dict(default='data.pickle', type=str, help='save path of dataset')),
        ('--save_model', dict(default='saved_model.pt', type=str, help='save path of model')),
        ('--gpu', dict(default=0, type=int, help='gpu mode')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()
    return fill_config(parsed)