Skip to content

Commit 671a651

Browse files
antmarakisnorvig
authored andcommitted
Update MNIST Functions for Fashion (#646)
* Update notebook.py * Update notebook.py
1 parent aa1a31f commit 671a651

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

notebook.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2):
9595
# MNIST
9696

9797

98-
def load_MNIST(path="aima-data/MNIST"):
98+
def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
9999
import os, struct
100100
import array
101101
import numpy as np
102102
from collections import Counter
103103

104+
if fashion:
105+
path = "aima-data/MNIST/Fashion"
106+
104107
plt.rcParams.update(plt.rcParamsDefault)
105108
plt.rcParams['figure.figsize'] = (10.0, 8.0)
106109
plt.rcParams['image.interpolation'] = 'nearest'
@@ -143,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST"):
143146
return(train_img, train_lbl, test_img, test_lbl)
144147

145148

146-
def show_MNIST(labels, images, samples=8):
147-
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
149+
digit_classes = [str(i) for i in range(10)]
150+
fashion_classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
151+
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
152+
153+
154+
def show_MNIST(labels, images, samples=8, fashion=False):
155+
if not fashion:
156+
classes = digit_classes
157+
else:
158+
classes = fashion_classes
159+
148160
num_classes = len(classes)
149161

150162
for y, cls in enumerate(classes):
@@ -161,13 +173,19 @@ def show_MNIST(labels, images, samples=8):
161173
plt.show()
162174

163175

164-
def show_ave_MNIST(labels, images):
165-
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
176+
def show_ave_MNIST(labels, images, fashion=False):
177+
if not fashion:
178+
item_type = "Digit"
179+
classes = digit_classes
180+
else:
181+
item_type = "Apparel"
182+
classes = fashion_classes
183+
166184
num_classes = len(classes)
167185

168186
for y, cls in enumerate(classes):
169187
idxs = np.nonzero([i == y for i in labels])
170-
print("Digit", y, ":", len(idxs[0]), "images.")
188+
print(item_type, y, ":", len(idxs[0]), "images.")
171189

172190
ave_img = np.mean(np.vstack([images[i] for i in idxs[0]]), axis = 0)
173191
#print(ave_img.shape)

0 commit comments

Comments
 (0)