-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtf_config.py
96 lines (74 loc) · 3.36 KB
/
tf_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import matplotlib.pyplot as plt
# 30/3/23 DH: Refactor of model creation + DB model access
from tf_model import *
#from gspread_errors import *
class TFConfig:
    """Hold the MNIST data, a ``TFModel`` wrapper and helpers to train, evaluate and view them.

    ``__init__`` sets ``x_train``, ``y_train``, ``x_test``, ``y_test`` and ``tfModel``.
    ``build()`` stores its training parameters plus the trained model, and
    ``modelEval()`` records accuracy strings (``startPercent``, ``lowestPercent``,
    ``accuracies``).
    """

    def __init__(self, integer=False) -> None:
        """Load the MNIST dataset and create the ``TFModel`` wrapper.

        Dataset: http://yann.lecun.com/exdb/mnist/ ("The training set contains 60000
        examples, and the test set 10000 examples").

        Args:
            integer: when falsy (the default) the image arrays are divided by 255.0,
                turning the 0-255 integer pixel values returned by ``load_data()``
                into floats in [0, 1]. When truthy the raw integer pixels are kept.
        """
        # 8/5/23 DH: Load and prepare the MNIST dataset.
        mnist = tf.keras.datasets.mnist
        # https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist/load_data()
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()

        # 6/5/23 DH: '$ tf-test train' appears to be more accurate for floats vs integers
        # (5800 vs 7600 for 20Dense-700Train).
        # Scaling the pixels into [0, 1] keeps the network inputs small, which usually
        # makes gradient-based training better behaved than feeding raw 0-255 values.
        # `not integer` (rather than `is False`) so that 0/None also mean "normalise".
        if not integer:
            self.x_train = self.x_train / 255.0
            self.x_test = self.x_test / 255.0

        self.tfModel = TFModel()

    # 6/5/23 DH: 'self.digitDict' used to be single image per digit key (rather than a list)
    def displayDictImg(self, imgDict, elem):
        """Display ``imgDict[elem]`` with ``displayImg``.

        The lookup is a plain ``imgDict[elem]``, so a missing key propagates the
        container's lookup error (KeyError for a dict). Returns None.
        """
        self.displayImg(imgDict[elem])

    def displayImgList(self, imgList):
        """Show each image in turn and collect the ones the user reacted to.

        6/5/23 DH: If a key/mouse is pressed on an image then it goes into prune mode.
        The index of every image that got a key/mouse press is printed.

        Returns:
            ``(True, pressed)`` where ``pressed`` lists, in display order, the images
            on which a key or mouse button was pressed; or ``(False, imgList)`` (the
            original list object) when every image timed out without input.
        """
        pressedImgs = []
        for index, img in enumerate(imgList):
            # displayImg returns None only on timeout; True/False mean key/mouse press,
            # and both count as "pressed".
            if self.displayImg(img) is not None:
                print(index)
                pressedImgs.append(img)
        if pressedImgs:
            return True, pressedImgs
        return False, imgList

    def displayImg(self, img, timeout=1):
        """Draw ``img`` with a reversed grayscale colormap and wait for user input.

        Args:
            img: anything ``plt.imshow`` accepts (e.g. one MNIST digit image).
            timeout: seconds to wait for a key/mouse press.

        Returns:
            The result of ``plt.waitforbuttonpress``: ``True`` if a key was pressed,
            ``False`` for a mouse click, ``None`` on timeout.
        """
        # https://matplotlib.org/3.5.3/api/_as_gen/matplotlib.pyplot.html
        plt.imshow(img, cmap='gray_r')
        # 22/1/23 DH: Calling function without '()' does not return an error but does NOT
        # execute it (like function pointer)
        plt.draw()
        return plt.waitforbuttonpress(timeout=timeout)

    def modelEval(self, start=False):
        """Evaluate ``self.tfModel.model`` on the test set and return the accuracy.

        Args:
            start: when truthy, (re)initialise the baseline. ``startPercent`` and
                ``lowestPercent`` are set to this result and ``accuracies`` is reset
                to an empty list before this result is appended.

        Returns:
            ``evalRes[1]`` formatted as a string with two decimals, e.g. ``"0.97"``.
            Keras accuracy metrics are fractions in [0, 1], so despite the
            ``Percent`` naming this is not a percentage. The same string is appended
            to ``self.accuracies`` whenever that list exists, i.e. after any
            ``start=True`` call.
        """
        print("--- model.evaluate() ---")
        print(f"Using x_test + y_test ({self.x_test.shape[0]}): ")
        evalRes = self.tfModel.model.evaluate(self.x_test, self.y_test, verbose=2)
        # evaluate() returns [loss, *metrics]. Index 1 assumes accuracy is the first
        # compiled metric -- NOTE(review): confirm against TFModel's compile() call.
        accuracyPercent = f"{evalRes[1]:.2f}"
        print("evaluate() accuracy:", accuracyPercent)
        print("------------------------\n")

        if start:
            self.startPercent = accuracyPercent
            # 28/4/23 DH: Lowest percent gets lowered as appropriate during the retraining
            self.lowestPercent = accuracyPercent
            self.accuracies = []
        # 'accuracies' only exists once a start=True evaluation has happened.
        if hasattr(self, 'accuracies'):
            self.accuracies.append(accuracyPercent)
        return accuracyPercent

    def build(self, paramDict):
        """Train a new model from ``paramDict`` and record its baseline accuracy.

        Args:
            paramDict: indexed with the keys 'dense1', 'dropout1', 'trainingNum',
                'x_trainSet', 'y_trainSet' and 'epochs' (in that order). Each value is
                stored on ``self`` under the same name, so a missing key raises
                before training starts. 'trainingNum' is stored but not used here.
        """
        self.dense1 = paramDict['dense1']
        self.dropout1 = paramDict['dropout1']
        self.trainingNum = paramDict['trainingNum']
        # 24/4/23 DH:
        self.x_trainSet = paramDict['x_trainSet']
        self.y_trainSet = paramDict['y_trainSet']
        # 23/4/23 DH:
        self.epochs = paramDict['epochs']

        # NOTE(review): self.model holds createTrainedModel's return value, while
        # modelEval() evaluates self.tfModel.model -- presumably createTrainedModel
        # also sets that attribute; confirm in tf_model.
        self.model = self.tfModel.createTrainedModel(dense1=self.dense1, dropout1=self.dropout1,
                                                     x_trainSet=self.x_trainSet, y_trainSet=self.y_trainSet,
                                                     epochs=self.epochs)
        self.modelEval(start=True)