@@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2):
95
95
# MNIST
96
96
97
97
98
- def load_MNIST (path = "aima-data/MNIST" ):
98
+ def load_MNIST (path = "aima-data/MNIST/Digits" , fashion = False ):
99
99
import os , struct
100
100
import array
101
101
import numpy as np
102
102
from collections import Counter
103
103
104
+ if fashion :
105
+ path = "aima-data/MNIST/Fashion"
106
+
104
107
plt .rcParams .update (plt .rcParamsDefault )
105
108
plt .rcParams ['figure.figsize' ] = (10.0 , 8.0 )
106
109
plt .rcParams ['image.interpolation' ] = 'nearest'
@@ -143,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST"):
143
146
return (train_img , train_lbl , test_img , test_lbl )
144
147
145
148
146
- def show_MNIST (labels , images , samples = 8 ):
147
- classes = ["0" , "1" , "2" , "3" , "4" , "5" , "6" , "7" , "8" , "9" ]
149
+ digit_classes = [str (i ) for i in range (10 )]
150
+ fashion_classes = ["T-shirt/top" , "Trouser" , "Pullover" , "Dress" , "Coat" ,
151
+ "Sandal" , "Shirt" , "Sneaker" , "Bag" , "Ankle boot" ]
152
+
153
+
154
+ def show_MNIST (labels , images , samples = 8 , fashion = False ):
155
+ if not fashion :
156
+ classes = digit_classes
157
+ else :
158
+ classes = fashion_classes
159
+
148
160
num_classes = len (classes )
149
161
150
162
for y , cls in enumerate (classes ):
@@ -161,13 +173,19 @@ def show_MNIST(labels, images, samples=8):
161
173
plt .show ()
162
174
163
175
164
- def show_ave_MNIST (labels , images ):
165
- classes = ["0" , "1" , "2" , "3" , "4" , "5" , "6" , "7" , "8" , "9" ]
176
+ def show_ave_MNIST (labels , images , fashion = False ):
177
+ if not fashion :
178
+ item_type = "Digit"
179
+ classes = digit_classes
180
+ else :
181
+ item_type = "Apparel"
182
+ classes = fashion_classes
183
+
166
184
num_classes = len (classes )
167
185
168
186
for y , cls in enumerate (classes ):
169
187
idxs = np .nonzero ([i == y for i in labels ])
170
- print ("Digit" , y , ":" , len (idxs [0 ]), "images." )
188
+ print (item_type , y , ":" , len (idxs [0 ]), "images." )
171
189
172
190
ave_img = np .mean (np .vstack ([images [i ] for i in idxs [0 ]]), axis = 0 )
173
191
#print(ave_img.shape)
0 commit comments