train.py
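"""Train a Transformer language model on the pre-processed corpus.

Usage sketch (assumes ./processed_data already exists and that
trainer.py / dataloader.py define TransformerModel, Trainer, and
device; the flag values shown just restate the defaults):

    python train.py -d_model 50 -nhead 2 -nlayers 2 -epochs 10

Resuming presumably works via -train_from with a checkpoint saved
under -model_save_dir (the exact filename depends on how Trainer
saves models):

    python train.py -train_from ./models/<checkpoint>.pt
"""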
import os

import torch
from trainer import *      # expected to provide TransformerModel, Trainer, and device
from dataloader import *
from argparse import ArgumentParser

data_dir = './processed_data'
train_data = torch.load(os.path.join(data_dir, 'train_data.pt'), map_location=device)
val_data = torch.load(os.path.join(data_dir, 'val_data.pt'), map_location=device)
test_data = torch.load(os.path.join(data_dir, 'test_data.pt'), map_location=device)
vocab = torch.load(os.path.join(data_dir, 'vocab.pt'), map_location=device)

# Synthetic token-id data, only for debugging (pair with args.ntokens = 100 below):
# train_data = torch.randint(0, 100, (150, 50))
# val_data = torch.randint(0, 100, (100, 50))
# test_data = torch.randint(0, 100, (100, 50))
if __name__ == "__main__":
    parser = ArgumentParser(conflict_handler="resolve")
    # parser.add_argument("-emsize", type=int, default=50, help="embedding dimension")
    parser.add_argument("-d_model", type=int, default=50, help="embedding dimension")
    parser.add_argument("-d_v", type=int, default=40, help="values dimension")
    parser.add_argument("-d_k", type=int, default=40, help="keys dimension")
    parser.add_argument("-dim_feedforward", type=int, default=2048, help="feed-forward layer dimension")
    parser.add_argument("-nlayers", type=int, default=2, help="number of encoder layers")
    parser.add_argument("-nhead", type=int, default=2, help="number of attention heads")
    parser.add_argument("-dropout", type=float, default=0.2, help="dropout rate")
    parser.add_argument("-lr", type=float, default=5.0, help="learning rate")
    parser.add_argument("-model_save_dir", type=str, default='./models', help="directory for model checkpoints")
    parser.add_argument("-logs_dir", type=str, default='./training_logs', help="directory for training logs")
    parser.add_argument("-epochs", type=int, default=10, help="number of training epochs")
    parser.add_argument("-bptt", type=int, default=35, help="sequence length (backprop through time)")
    parser.add_argument("-atten_type", type=str, default='multilinear', help="attention type")
    parser.add_argument("-log_interval", type=int, default=2000, help="batches between log messages")
    parser.add_argument("-train_from", type=str, default=None, help="checkpoint to resume training from")
    args = parser.parse_args()
    args.ntokens = len(vocab)  # size of vocabulary
    # only for debugging:
    # args.ntokens = 100

    # warmup_steps = 4000
    # label_smoothing

    # Build the model and report its trainable parameter count.
    model = TransformerModel(args).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('The model has {} parameters.'.format(pytorch_total_params))

    trainer = Trainer(model, train_data, val_data, test_data, args)
    trainer.train(args.epochs)