-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
129 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from PIL import Image, ImageFilter | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
def imageprepare(img): | ||
""" | ||
This function returns the pixel values. | ||
The imput is an image file location. | ||
""" | ||
im = Image.open(img).convert('L') | ||
width = float(im.size[0]) | ||
height = float(im.size[1]) | ||
newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels | ||
|
||
if width > height: # check which dimension is bigger | ||
# Width is bigger. Width becomes 20 pixels. | ||
nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width | ||
if (nheight == 0): # rare case but minimum is 1 pixel | ||
nheight = 1 | ||
# resize and sharpen | ||
img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) | ||
wtop = int(round(((28 - nheight) / 2), 0)) # calculate horizontal position | ||
newImage.paste(img, (4, wtop)) # paste resized image on white canvas | ||
else: | ||
# Height is bigger. Heigth becomes 20 pixels. | ||
nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height | ||
if (nwidth == 0): # rare case but minimum is 1 pixel | ||
nwidth = 1 | ||
# resize and sharpen | ||
img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) | ||
wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition | ||
newImage.paste(img, (wleft, 4)) # paste resized image on white canvas | ||
|
||
# newImage.save("sample.png | ||
|
||
tv = list(newImage.getdata()) # get pixel values | ||
|
||
# normalize pixels to 0 and 1. 0 is pure white, 1 is pure black. | ||
tva = [(255 - x) * 1.0 / 255.0 for x in tv] | ||
#print(tva) | ||
return tva | ||
|
||
img_array = imageprepare('./test3.jpeg')#file path here | ||
# mnist IMAGES are 28x28=784 pixels | ||
print(len(x)) | ||
img_array_reshaped = np.array(img_array).reshape(28,28) | ||
|
||
plt.imshow(img_array_reshaped, cmap='gist_gray') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# https://github.com/keras-team/keras/blob/master/examples/mnist_mlp.py | ||
# https://nextjournal.com/schmudde/ml4a-mnist | ||
# https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py (CONVOLUTIONAL) | ||
# https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/01_Simple_Linear_Model.ipynb (TF) | ||
|
||
import keras | ||
from keras.datasets import mnist | ||
from keras.models import Sequential | ||
from keras.layers import Dense, Dropout | ||
from keras.optimizers import RMSprop | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import random | ||
|
||
''' | ||
TRY: | ||
-> Loss functions: | ||
Log loss | ||
KL Divergence | ||
Mean Squared Error | ||
Categorical_crossentropy (atual) | ||
-> Optimizers: | ||
RMSprop (atual) | ||
''' | ||
|
||
|
||
batch_size = 128 | ||
num_classes = 10 # {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} = 10 | ||
epochs = 20 | ||
|
||
(x_train, y_train), (x_test, y_test) = mnist.load_data() | ||
|
||
plt.figure() | ||
numbers = np.concatenate([np.concatenate([x_train[i] for i in [int(random.random() * len(x_train)) for i in range(9)]], axis=1) for i in range(9)], axis=0) | ||
plt.imshow(numbers, cmap='gist_gray', interpolation='none') | ||
plt.xticks([]) | ||
plt.yticks([]) | ||
plt.xlabel('Alguns números do MNIST') | ||
plt.show() | ||
|
||
''' | ||
r = int(random.random() * len(x_train)) | ||
numbers2 = np.array([x_train[x] for x in range(r-5, r)]) | ||
plt.xlabel('Alguns números do MNIST, começando de ' + str(r-5)) | ||
plt.imshow(numbers2.reshape(int(numbers2.size/28), 28), cmap='gist_gray') | ||
''' | ||
|
||
x_train = x_train.reshape(60000, 784) | ||
x_test = x_test.reshape(10000, 784) | ||
x_train = x_train.astype('float32') | ||
x_test = x_test.astype('float32') | ||
|
||
x_train /= 255 # Cores [0, 255] -> [0, 1] | ||
x_test /= 255 | ||
|
||
print(x_train.shape[0], 'números/arquivos de treino') | ||
print(x_test.shape[0], 'números/arquivos de teste') | ||
|
||
y_train = keras.utils.to_categorical(y_train, num_classes) | ||
y_test = keras.utils.to_categorical(y_test, num_classes) | ||
|
||
classifier = Sequential() | ||
classifier.add(Dense(512, activation='relu', input_shape=(784,))) | ||
classifier.add(Dropout(0.2)) | ||
classifier.add(Dense(512, activation='relu')) | ||
classifier.add(Dropout(0.2)) | ||
classifier.add(Dense(num_classes, activation='softmax')) | ||
|
||
#classifier.summary() | ||
|
||
classifier.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy']) | ||
|
||
classifier.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test)) | ||
|
||
score = classifier.evaluate(x_test, y_test, verbose=1) | ||
|
||
print('Test loss:', score[0]) | ||
print('Test accuracy:', score[1]) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.