import json
from tqdm import tqdm
import numpy as np
import os
import torch
from torch import nn, optim
from tensorboard_logger import configure, log_value
from ntm import NTM
from ntm.datasets import CopyDataset, RepeatCopyDataset, AssociativeDataset, NGram, PrioritySort
from ntm.args import get_parser
args = get_parser().parse_args()
configure("runs/")
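# Assuming TensorBoard is installed, the event files that tensorboard_logger
# writes to runs/ can be inspected with: tensorboard --logdir runs/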
# ----------------------------------------------------------------------------
# -- initialize datasets, model, criterion and optimizer
# ----------------------------------------------------------------------------
args.task_json = 'ntm/tasks/copy.json'
'''
args.task_json = 'ntm/tasks/repeatcopy.json'
args.task_json = 'ntm/tasks/associative.json'
args.task_json = 'ntm/tasks/ngram.json'
args.task_json = 'ntm/tasks/prioritysort.json'
'''
task_params = json.load(open(args.task_json))
dataset = CopyDataset(task_params)
'''
dataset = RepeatCopyDataset(task_params)
dataset = AssociativeDataset(task_params)
dataset = NGram(task_params)
dataset = PrioritySort(task_params)
'''
"""
For the Copy task, input_size: seq_width + 2, output_size: seq_width
For the RepeatCopy task, input_size: seq_width + 2, output_size: seq_width + 1
For the Associative task, input_size: seq_width + 2, output_size: seq_width
For the NGram task, input_size: 1, output_size: 1
For the Priority Sort task, input_size: seq_width + 1, output_size: seq_width
"""
ntm = NTM(input_size=task_params['seq_width'] + 2,
          output_size=task_params['seq_width'],
          controller_size=task_params['controller_size'],
          memory_units=task_params['memory_units'],
          memory_unit_size=task_params['memory_unit_size'],
          num_heads=task_params['num_heads'])
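# A commented-out sketch of the same constructor call for the Priority Sort task,
# using the sizes listed in the docstring above (input_size: seq_width + 1,
# output_size: seq_width); the remaining arguments are assumed to be unchanged.
'''
ntm = NTM(input_size=task_params['seq_width'] + 1,
          output_size=task_params['seq_width'],
          controller_size=task_params['controller_size'],
          memory_units=task_params['memory_units'],
          memory_unit_size=task_params['memory_unit_size'],
          num_heads=task_params['num_heads'])
'''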
criterion = nn.BCELoss()
# As the learning rate is task-specific, the lr argument could instead be moved
# to the task json file (see the sketch below the optimizer definitions)
optimizer = optim.RMSprop(ntm.parameters(),
                          lr=args.lr,
                          alpha=args.alpha,
                          momentum=args.momentum)
'''
optimizer = optim.Adam(ntm.parameters(), lr=args.lr,
                       betas=(args.beta1, args.beta2))
'''
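# A minimal sketch of the task-specific learning rate idea noted above, assuming
# a (hypothetical) 'lr' key were added to the task json file; falls back to args.lr:
'''
lr = task_params.get('lr', args.lr)
optimizer = optim.RMSprop(ntm.parameters(),
                          lr=lr,
                          alpha=args.alpha,
                          momentum=args.momentum)
'''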
args.saved_model = 'saved_model_copy.pt'
'''
args.saved_model = 'saved_model_repeatcopy.pt'
args.saved_model = 'saved_model_associative.pt'
args.saved_model = 'saved_model_ngram.pt'
args.saved_model = 'saved_model_prioritysort.pt'
'''
cur_dir = os.getcwd()
PATH = os.path.join(cur_dir, args.saved_model)
# ----------------------------------------------------------------------------
# -- basic training loop
# ----------------------------------------------------------------------------
losses = []
errors = []
for iter in tqdm(range(args.num_iters)):
    optimizer.zero_grad()
    ntm.reset()

    data = dataset[iter]
    input, target = data['input'], data['target']
    out = torch.zeros(target.size())

    # -------------------------------------------------------------------------
    # loop for other tasks
    # -------------------------------------------------------------------------
    for i in range(input.size()[0]):
        # unsqueeze to add a leading (batch) dimension and keep tensor shapes
        # consistent, as torch.cat inside the model was throwing an error otherwise
        in_data = torch.unsqueeze(input[i], 0)
        ntm(in_data)

    # passing a zero vector as input while generating the target sequence
    in_data = torch.unsqueeze(torch.zeros(input.size()[1]), 0)
    for i in range(target.size()[0]):
        out[i] = ntm(in_data)
    # -------------------------------------------------------------------------

    # -------------------------------------------------------------------------
    # loop for NGram task
    # -------------------------------------------------------------------------
    '''
    for i in range(task_params['seq_len'] - 1):
        in_data = input[i].view(1, -1)
        ntm(in_data)
        target_data = torch.zeros([1]).view(1, -1)
        out[i] = ntm(target_data)
    '''
    # -------------------------------------------------------------------------

    loss = criterion(out, target)
    losses.append(loss.item())
    loss.backward()
    # clips gradients to the range [-10, 10]; a slight but insignificant
    # deviation from the paper, where they are clipped to (-10, 10)
    nn.utils.clip_grad_value_(ntm.parameters(), 10)
    optimizer.step()

    binary_output = out.clone()
    binary_output = binary_output.detach().apply_(lambda x: 0 if x < 0.5 else 1)

    # sequence prediction error is calculated in bits per sequence
    error = torch.sum(torch.abs(binary_output - target))
    errors.append(error.item())
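    # A vectorized equivalent of the thresholding above (sketch):
    #     binary_output = (out.clone().detach() >= 0.5).float()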
    # ---logging---
    if iter % 200 == 0:
        print('Iteration: %d\tLoss: %.2f\tError in bits per sequence: %.2f' %
              (iter, np.mean(losses), np.mean(errors)))
        log_value('train_loss', np.mean(losses), iter)
        log_value('bit_error_per_sequence', np.mean(errors), iter)
        losses = []
        errors = []
# ---saving the model---
torch.save(ntm.state_dict(), PATH)
# torch.save(ntm, PATH)
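# A minimal sketch of restoring the saved weights later (e.g. for evaluation),
# assuming the NTM is re-created with the same constructor arguments as above:
'''
ntm.load_state_dict(torch.load(PATH))
ntm.eval()
'''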