-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtesting.py
92 lines (75 loc) · 3.58 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from src.training_data import Data
from src.cnn import CNN
from src.model import Model
# TODO: Add unit tests using pytest.
def loadData_testing(testingData):
    """Load every configured dataset and sanity-check its contents.

    For each name in ``testingData['dataset']`` the data is loaded and its
    raw values are checked to lie in [0, 255]. When the matching
    ``testingData['colour']`` flag is falsy, the data is converted with
    ``rgb2greyScale()`` and ``blackAndWhite()``, after which the values are
    expected to lie in [0, 1]. Finally the spatial dimensions and channel
    count are compared with ``testingData['resolution']``.

    Args:
        testingData: dict holding parallel lists under ``'dataset'``,
            ``'colour'`` and ``'resolution'``; entry ``i`` of each list
            describes the same dataset.

    Raises:
        AssertionError: if a scale, dimension or channel check fails
            (these checks use ``assert`` and are skipped under ``python -O``).
    """
    print('TESTING DATA LOADING:')
    for idx, name in enumerate(testingData['dataset']):
        loaded = Data(name)
        loaded.load()
        checkScale(loaded, [0, 255])

        isColour = testingData['colour'][idx]
        if not isColour:
            # Greyscale conversion followed by blackAndWhite() is expected
            # to leave every value inside [0, 1].
            loaded.rgb2greyScale()
            loaded.blackAndWhite()
            checkScale(loaded, [0, 1])

        expected = testingData['resolution'][idx]
        checkDim(loaded, expected)
        checkChannels(loaded, expected)
        print('Data loaded correctly for "{}" dataset'.format(name))
    print('\n')
def train_testing(testingData):
    """Smoke-test the CNN pipeline on a single sample from each dataset.

    For every name in ``testingData['dataset']`` the data is loaded, cut
    down to one sample per split with ``getOneDataPoint``, and converted
    with ``rgb2greyScale()`` when the matching ``testingData['colour']``
    flag is falsy. It is then pushed through a deliberately tiny CNN:
    construct, build, compile, train for one epoch, and predict.

    Args:
        testingData: dict holding parallel lists under ``'dataset'`` and
            ``'colour'`` (``'resolution'`` is not used here).

    Raises:
        ValueError: if any stage fails. The message names the stage and the
            original exception is attached as ``__cause__``, so its
            traceback is reported as the direct cause.
    """
    print('TESTING CNN TRAINING:')
    for i, dataset in enumerate(testingData['dataset']):
        data = Data(dataset)
        data.load()
        # One sample per split keeps the smoke test fast.
        data = getOneDataPoint(data)
        if not testingData['colour'][i]:
            data.rgb2greyScale()

        try:
            cnn = CNN(data)
        except Exception as e:
            raise ValueError('Failed when creating model. {}'.format(e)) from e

        try:
            # Minimal settings: one conv block, one filter, kernel size 1.
            cnn.build(nConvBlocks=1, nFilters=1, kernelSize=1, stride=2)
        except Exception as e:
            raise ValueError('Failed when building model. {}'.format(e)) from e

        try:
            model = Model(cnn)
            model.compile(optimizer='adam', loss='mean_squared_error')
        except Exception as e:
            raise ValueError('Failed when compiling model. {}'.format(e)) from e

        try:
            model.train(epochs=1, nBatch=1, earlyStopPatience=1)
        except Exception as e:
            raise ValueError('Failed when training model. {}'.format(e)) from e

        try:
            model.predict()
        except Exception as e:
            raise ValueError('Failed when predicting. {}'.format(e)) from e

        print('Training successfully tested for "{}" dataset'.format(dataset))
        # Release this iteration's objects before the next dataset is loaded.
        # cnn is dropped too: it was constructed from data and may keep it
        # (and the built network) alive otherwise.
        del model, cnn, data
def checkDim(data, refResolution):
    """Assert that every split has the expected spatial dimensions.

    Compares axes 1 and 2 of ``data.x_train``, ``data.x_val`` and
    ``data.x_test`` with the first two entries of ``refResolution``. Axis 0
    is skipped; presumably it is the sample axis.

    Args:
        data: object exposing array-like ``x_train``, ``x_val`` and
            ``x_test`` attributes with a ``.shape``.
        refResolution: expected resolution; only its first two entries are
            used. Any sequence is accepted. It is normalised to a tuple
            because a tuple never compares equal to a list, which would
            otherwise make every check fail for list input.

    Raises:
        AssertionError: naming the first split (training, validation,
            testing order) whose dimensions differ.
    """
    expected = tuple(refResolution[0:2])
    splits = (
        ('Training', 'x_train'),
        ('Validation', 'x_val'),
        ('Testing', 'x_test'),
    )
    for label, attr in splits:
        actual = tuple(getattr(data, attr).shape[1:3])
        assert actual == expected, '{} dataset dimension is incorrect'.format(label)
def checkChannels(data, refResolution):
    """Assert that ``data`` reports the expected number of channels.

    Only the stored ``data.resolution`` is inspected; the arrays themselves
    are not examined.

    Args:
        data: object with a ``resolution`` sequence whose third entry is the
            channel count.
        refResolution: expected resolution; its third entry is the expected
            channel count.

    Raises:
        AssertionError: if the two channel counts differ.
    """
    reported, expected = data.resolution[2], refResolution[2]
    assert reported == expected, 'Data has an incorrect number of channels'
def checkScale(data, lim):
    """Assert that every split's values lie inside the inclusive range ``lim``.

    Splits are checked in training, validation, testing order, and each
    split's array is read only after the previous split has passed.

    Args:
        data: object exposing ``x_train``, ``x_val`` and ``x_test`` arrays
            that support ``.min()`` and ``.max()``.
        lim: two-element sequence ``[low, high]`` of inclusive bounds.

    Raises:
        AssertionError: naming the first split with a value outside
            ``[low, high]``.
    """
    low, high = lim[0], lim[1]
    splits = (
        ('Training', 'x_train'),
        ('Validation', 'x_val'),
        ('Testing', 'x_test'),
    )
    for split, attr in splits:
        values = getattr(data, attr)
        inRange = values.min() >= low and values.max() <= high
        assert inRange, '{} dataset not scaled between {} and {}'.format(split, low, high)
def getOneDataPoint(data):
    """Reduce each split of ``data`` to its first sample, in place.

    ``x_train``, ``x_val`` and ``x_test`` are each replaced by their first
    element reshaped to ``(1, res[0], res[1], res[2])``, where ``res`` is
    ``data.resolution``. This keeps a leading axis of length one.

    Args:
        data: object with ``x_train``, ``x_val`` and ``x_test`` arrays and a
            ``resolution`` sequence of at least three entries.

    Returns:
        The same ``data`` object, after mutation.
    """
    for attr in ('x_train', 'x_val', 'x_test'):
        first = getattr(data, attr)[0]
        res = data.resolution
        setattr(data, attr, first.reshape(1, res[0], res[1], res[2]))
    return data
if __name__ == '__main__':
    # One row per dataset: (name, colour flag, expected resolution).
    # Each resolution is two spatial dims followed by the channel count.
    # The parallel lists are derived from these rows so they cannot drift
    # out of step.
    # NOTE(review): the hand-written lists previously carried an extra,
    # never-used (False, (80, 160, 1)) colour/resolution pair with no dataset
    # name. If a sixth dataset was intended, add its row here.
    testCases = [
        ('mnist', False, (28, 28, 1)),
        ('afreightdata_test', False, (120, 160, 1)),
        ('afreightdata_test', True, (120, 160, 3)),
        ('beam_test', False, (100, 300, 1)),
        ('beam_homog_test', False, (80, 160, 1)),
    ]
    testingData = {
        'dataset': [name for name, _, _ in testCases],
        'colour': [colour for _, colour, _ in testCases],
        'resolution': [res for _, _, res in testCases],
    }
    loadData_testing(testingData)
    train_testing(testingData)