From c120cd52a05b29fdf564bca1d8c5d678c388165e Mon Sep 17 00:00:00 2001 From: nic Date: Sat, 8 Dec 2018 12:38:57 +0100 Subject: [PATCH] Snake eyes dataset, and a minimal example --- examples/snakeeyes_cnn.py | 70 ++++++++++++++++++++++++++++ keras_contrib/datasets/snake_eyes.py | 60 ++++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 examples/snakeeyes_cnn.py create mode 100644 keras_contrib/datasets/snake_eyes.py diff --git a/examples/snakeeyes_cnn.py b/examples/snakeeyes_cnn.py new file mode 100644 index 000000000..33b6d85a1 --- /dev/null +++ b/examples/snakeeyes_cnn.py @@ -0,0 +1,70 @@ +'''Trains a simple convnet on the Snake Eyes dataset. + +This is largely based on the Keras MNIST example. Training just 1 epoch can +achieve 77% accuracy, and 95% with 2 epochs. Faster training or simpler models +might be possible. An accuracy of 100% can also be attained with adjustments. +''' + +from __future__ import print_function +import keras +from keras_contrib.datasets import snake_eyes +from keras.models import Sequential +from keras.layers import Dense, Dropout, Flatten +from keras.layers import Conv2D, MaxPooling2D +from keras import backend as K + +batch_size = 128 +num_classes = 12 +epochs = 12 + +# input image dimensions +img_rows, img_cols = 20, 20 + +# the data, split between train and test sets +(x_train, y_train), (x_test, y_test) = snake_eyes.load_data() + +if K.image_data_format() == 'channels_first': + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) +else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + +x_train = x_train.astype('float32') +x_test = x_test.astype('float32') +x_train /= 255 +x_test /= 255 +print('x_train shape:', x_train.shape) +print(x_train.shape[0], 'train samples') +print(x_test.shape[0], 'test samples') + +# convert class vectors to binary class matrices +y_train = keras.utils.to_categorical(y_train - 1, num_classes) +y_test = keras.utils.to_categorical(y_test - 1, num_classes) + +model = Sequential() +model.add(Conv2D(32, kernel_size=(3, 3), + activation='relu', + input_shape=input_shape)) +model.add(Conv2D(64, (3, 3), activation='relu')) +model.add(MaxPooling2D(pool_size=(2, 2))) +model.add(Dropout(0.25)) +model.add(Flatten()) +model.add(Dense(128, activation='relu')) +model.add(Dropout(0.5)) +model.add(Dense(num_classes, activation='softmax')) + +model.compile(loss=keras.losses.categorical_crossentropy, + optimizer=keras.optimizers.Adadelta(), + metrics=['accuracy']) + +model.fit(x_train, y_train, + batch_size=batch_size, + epochs=epochs, + verbose=1, + validation_data=(x_test, y_test)) +score = model.evaluate(x_test, y_test, verbose=0) +print('Test loss:', score[0]) +print('Test accuracy:', score[1]) diff --git a/keras_contrib/datasets/snake_eyes.py b/keras_contrib/datasets/snake_eyes.py new file mode 100644 index 000000000..44f3ff3b3 --- /dev/null +++ b/keras_contrib/datasets/snake_eyes.py @@ -0,0 +1,60 @@ +"""Snake Eyes dataset. + +More information available at http://github.com/nlw0/snake-eyes +or http://www.kaggle.com/nicw102168/snake-eyes/home +""" +import gzip +import os + +from keras.utils.data_utils import get_file +import numpy as np + + +def load_data(): + """Loads the Snake Eyes dataset. + # Returns + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + """ + + train_batches, test_set = load_batches() + x_train = np.concatenate([batch[0] for batch in train_batches]) + y_train = np.concatenate([batch[1] for batch in train_batches]) + return (x_train, y_train), test_set + + +def load_batches(): + """Loads the Snake Eyes dataset. + # Returns + Tuples of Numpy arrays: `[(x_train, y_train)...], (x_test, y_test)`. + """ + dirname = os.path.join('datasets', 'snake-eyes') + base = 'https://raw.githubusercontent.com/nlw0/snake-eyes/master/' + files = ['snakeeyes_00.dat.gz', 'snakeeyes_01.dat.gz', + 'snakeeyes_02.dat.gz', 'snakeeyes_03.dat.gz', + 'snakeeyes_04.dat.gz', 'snakeeyes_05.dat.gz', + 'snakeeyes_06.dat.gz', 'snakeeyes_07.dat.gz', + 'snakeeyes_08.dat.gz', 'snakeeyes_09.dat.gz', + 'snakeeyes_test.dat.gz'] + + paths = [] + for fname in files: + paths.append(get_file(fname, + origin=base + fname, + cache_subdir=dirname)) + + train_batches = [] + for data_path in paths[0:-1]: + with gzip.open(data_path, 'rb') as datafile: + data_train = np.frombuffer(datafile.read(), + np.uint8).reshape(100000, 401) + y_train = data_train[:, 0] + x_train = data_train[:, 1:].reshape(len(y_train), 20, 20) + train_batches.append((x_train, y_train)) + + with gzip.open(paths[-1], 'rb') as datafile: + data_test = np.frombuffer(datafile.read(), + np.uint8).reshape(10000, 401) + y_test = data_test[:, 0] + x_test = data_test[:, 1:].reshape(len(y_test), 20, 20) + + return train_batches, (x_test, y_test)