-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
80 lines (73 loc) · 2.48 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from RNN_models import *
import sys
import numpy as np
import time
from utils import load_pMNIST_data, load_synthetic_data, load_PTB_data
# Global variables, taken from positional command-line arguments:
#   argv[1] seed               argv[6]  hidden_units_per_module
#   argv[2] task               argv[7]  n_r
#   argv[3] sequence_length    argv[8]  learning_rate
#   argv[4] model_type         argv[9]  batch_size
#   argv[5] modules            argv[10] activation
# Any arguments beyond argv[10] are ignored.
args = sys.argv
if len(args) < 11:
    # Fail with a readable usage message instead of an IndexError on the
    # first missing argument.
    sys.exit('expected 10 arguments: seed task sequence_length model_type '
             'modules hidden_units_per_module n_r learning_rate batch_size '
             'activation')
data_path = './data/'
seed = int(args[1])
rng = np.random.RandomState(seed)  # seeded RNG handed to data loading / model
task = args[2]
sequence_length = int(args[3])
model_type = args[4]
modules = int(args[5])
hidden_units_per_module = int(args[6])
n_r = int(args[7])
learning_rate = np.float32(args[8])
batch_size = int(args[9])
activation = args[10]
# Defaults; some tasks override these in the task-configuration section below.
n_epoch = 100000
test_frequency = 200
f_out = 'linear'
# Task-specific network sizes and schedule.
# n_in / n_out and f_out are passed to the model constructor further down.
# n_epoch and test_frequency are not read again in this script (presumably
# kept to mirror the training script's settings -- confirm).
if task in ('MNIST', 'pMNIST'):
    n_in, n_out = 1, 10
    n_epoch = 100
    test_frequency = 5000 if batch_size == 1 else 20
elif task in ('copying', 'copyingVariable'):
    # Copying tasks switch the output layer to softmax.
    n_in, n_out = 10, 10
    f_out = 'softmax'
elif task in ('PTB', 'PTB_5'):
    n_in, n_out = 49, 49
    test_frequency = 4000 if batch_size == 1 else 400
    n_epoch = 20
else:
    # Every other task name is treated as a synthetic task.
    n_in, n_out = 2, 1
# Run identifier: encodes the hyper-parameters and is used as the checkpoint
# file name under <data_path>/saved_models/.
string = 'seed-%i__task-%s__T-%i__m-%i__nh-%i__nr-%i__model-%s__lr-%.6f__bs-%i' % (
    seed, task, sequence_length, modules, hidden_units_per_module,
    n_r, model_type, learning_rate, batch_size)
# Only the plain 'MNIST' task omits the activation suffix.
# NOTE(review): presumably kept so older MNIST checkpoint names still resolve;
# do not change without checking existing saved models.
if task != 'MNIST':
    string += '__f-%s' % (activation,)
# Load and save locations are the same file.
load_progress = save_progress = data_path + 'saved_models/' + string
print('This is for testing')
print(string)
sys.stdout.flush()
# Load the dataset for the selected task.
# Synthetic tasks return only train/valid splits, so `test_data` is left
# undefined for them; only `valid_set` is used afterwards.
if task in ('MNIST', 'pMNIST'):
    # pMNIST is the same loader with the pixel permutation switched on.
    train_set, valid_set, test_data = load_pMNIST_data(
        rng, data_path, perm=(task == 'pMNIST'))
elif task == 'PTB':
    train_set, valid_set, test_data = load_PTB_data(data_path)
elif task == 'PTB_5':
    train_set, valid_set, test_data = load_PTB_data(data_path, 5)
else:
    train_set, valid_set = load_synthetic_data(data_path, task, sequence_length)
# Build the model. model_type, f_out and activation are resolved by name via
# eval(), so they must name objects visible in this module (e.g. ones brought
# in by `from RNN_models import *`).
# NOTE(review): eval() executes arbitrary expressions taken from the command
# line -- acceptable for a local experiment script, but never feed these
# arguments from untrusted input.
start = time.time()
rnn = eval(model_type)(rng, n_in=n_in, n_out=n_out, m=modules, n_r=n_r,
                       hupm=hidden_units_per_module, task_type=task,
                       f_out=eval(f_out), f_act=eval(activation))
end = time.time() - start  # elapsed wall-clock seconds, not an end timestamp
print('It took %.6f seconds to build the model' % end)
# Run the model's test routine on the validation split. No training happens
# in this script. Presumably optimiser.test restores the checkpoint at
# load_progress before evaluating -- confirm in RNN_models.
rnn.optimiser.test(task, sequence_length, valid_set, load_progress)
# TODO: the return value of optimiser.test is discarded; the intended summary
# was `test_cost, test_accuracy = np.mean(results, axis=0)`.