diff --git a/.travis.yml b/.travis.yml
index 3e1c858..70c9304 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -21,6 +21,8 @@ install:
   - source activate test-environment
   # Attempt to install torchvision; on failure, revert back to pre-conda environment.
   - conda install -q -y torchvision -c soumith || export PATH="$OPATH"
+  # Install pandas
+  - conda install -q -y pandas
   - pip install -r requirements.txt
 # command to run tests
 script: py.test -v
diff --git a/Augmentor/ImageUtilities.py b/Augmentor/ImageUtilities.py
index 45d6656..53719be 100644
--- a/Augmentor/ImageUtilities.py
+++ b/Augmentor/ImageUtilities.py
@@ -256,6 +256,34 @@ def scan(source_directory, output_directory):
 
     return augmentor_images, class_labels
 
+def scan_dataframe(source_dataframe, image_col, category_col, output_directory):
+    try:
+        import pandas as pd
+    except ImportError:
+        raise ImportError("Pandas is required to use the scan_dataframe function.\nRun `pip install pandas` and try again.")
+
+    # Ensure the category column is treated as categorical.
+    cat_col_series = pd.Categorical(source_dataframe[category_col])
+    abs_output_directory = os.path.abspath(output_directory)
+    class_labels = list(enumerate(cat_col_series.categories))
+
+    augmentor_images = []
+
+    for image_path, cat_name, cat_id in zip(source_dataframe[image_col].values,
+                                            cat_col_series.get_values(),
+                                            cat_col_series.codes):
+
+        a = AugmentorImage(image_path=image_path, output_directory=abs_output_directory)
+        a.class_label = cat_name
+        a.class_label_int = cat_id
+        categorical_label = np.zeros(len(class_labels), dtype=np.uint32)
+        categorical_label[cat_id] = 1
+        a.categorical_label = categorical_label
+        a.file_format = os.path.splitext(image_path)[1].split(".")[1]
+        augmentor_images.append(a)
+
+    return augmentor_images, class_labels
+
 def scan_directory(source_directory):
     """
diff --git a/Augmentor/Pipeline.py b/Augmentor/Pipeline.py
index 004ec8d..f591f02 100644
--- a/Augmentor/Pipeline.py
+++ b/Augmentor/Pipeline.py
@@ -16,7 +16,7 @@ from builtins import *
 
 from .Operations import *
-from .ImageUtilities import scan_directory, scan, AugmentorImage
+from .ImageUtilities import scan_directory, scan, scan_dataframe, AugmentorImage
 
 import os
 import sys
@@ -128,6 +128,14 @@ def _populate(self, source_directory, output_directory, ground_truth_directory,
         # Scan the directory that user supplied.
         self.augmentor_images, self.class_labels = scan(source_directory, abs_output_directory)
 
+        self._check_images(abs_output_directory)
+
+    def _check_images(self, abs_output_directory):
+        """
+        Private method. Creates the output directories, then checks that all
+        of the images can be read, recording their dimensions.
+
+        :param abs_output_directory: The absolute path of the output directory.
+        :return: None.
+        """
         # Make output directory/directories
         if len(set(self.class_labels)) <= 1:  # Fixed bad bug by adding set() function here.
             if not os.path.exists(abs_output_directory):
@@ -142,7 +150,6 @@
                     os.makedirs(os.path.join(abs_output_directory, str(class_label[0])))
                 except IOError:
                     print("Insufficient rights to read or write output directory (%s)" % abs_output_directory)
-        # Check the images, read their dimensions, and remove them if they cannot be read
         # TODO: Do not throw an error here, just remove the image and continue.
         for augmentor_image in self.augmentor_images:
@@ -1526,3 +1533,48 @@ def get_ground_truth_paths(self):
             paths.append((augmentor_image.image_path, augmentor_image.ground_truth))
 
         return paths
+
+class DataFramePipeline(Pipeline):
+    def __init__(self, source_dataframe, image_col, category_col, output_directory="output", save_format=None):
+        """
+        Create a new Pipeline object pointing to a DataFrame containing the
+        paths to your original image dataset.
+
+        Create a new Pipeline object, using the :attr:`source_dataframe`
+        and the columns :attr:`image_col` for the path of each image and
+        :attr:`category_col` for the name of its category.
+
+        :param source_dataframe: A Pandas DataFrame where the images are located.
+        :param image_col: The name of the column containing the image paths.
+        :param category_col: The name of the column containing the category names.
+        :param output_directory: Specifies where augmented images should be
+         saved to the disk. Default is a directory named output, relative
+         to the current working directory.
+        :param save_format: The file format to use when saving newly created,
+         augmented images. Default is JPEG. Legal options are BMP, PNG, and
+         GIF.
+        :return: A :class:`DataFramePipeline` object.
+        """
+        super(DataFramePipeline, self).__init__(source_directory=None,
+                                                output_directory=output_directory,
+                                                save_format=save_format)
+        self._populate(source_dataframe,
+                       image_col,
+                       category_col,
+                       output_directory,
+                       save_format)
+
+    def _populate(self,
+                  source_dataframe,
+                  image_col,
+                  category_col,
+                  output_directory,
+                  save_format):
+        # Scan the DataFrame that the user supplied, using an absolute path
+        # for the output directory.
+        abs_output_directory = os.path.abspath(output_directory)
+        self.augmentor_images, self.class_labels = scan_dataframe(source_dataframe,
+                                                                  image_col,
+                                                                  category_col,
+                                                                  abs_output_directory)
+
+        self._check_images(abs_output_directory)
diff --git a/Augmentor/__init__.py b/Augmentor/__init__.py
index 9228d08..d6c2b9a 100644
--- a/Augmentor/__init__.py
+++ b/Augmentor/__init__.py
@@ -11,10 +11,10 @@
 """
 
-from .Pipeline import Pipeline
+from .Pipeline import Pipeline, DataFramePipeline
 
 __author__ = """Marcus D. Bloice"""
 __email__ = 'marcus.bloice@medunigraz.at'
 __version__ = '0.2.0'
 
-__all__ = ['Pipeline']
+__all__ = ['Pipeline', 'DataFramePipeline']
diff --git a/notebooks/Augmentor_Keras_DataFrame.ipynb b/notebooks/Augmentor_Keras_DataFrame.ipynb
new file mode 100644
index 0000000..6e27ecc
--- /dev/null
+++ b/notebooks/Augmentor_Keras_DataFrame.ipynb
@@ -0,0 +1,505 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Training a Neural Network using Augmentor and Keras\n",
+    "\n",
+    "In this notebook, we will train a simple convolutional neural network on the MNIST dataset, using a Pandas DataFrame to describe the data and Augmentor to augment images on the fly using a generator.\n",
+    "\n",
+    "## Import Required Libraries\n",
+    "\n",
+    "We start by making a number of imports:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os, sys\n",
+    "sys.path.append('..')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/anaconda3/envs/pyqae_base/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
+      "  from ._conv import register_converters as _register_converters\n",
+      "Using TensorFlow backend.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import Augmentor\n",
+    "\n",
+    "import keras\n",
+    "from keras.models import Sequential\n",
+    "from keras.layers import Dense, Dropout, Flatten\n",
+    "from keras.layers import Conv2D, MaxPooling2D\n",
+    "\n",
+    "import numpy as np\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define a Convolutional Neural Network\n",
+    "\n",
+    "Once the libraries have been imported, we define a small convolutional neural network. See the Keras documentation for details of this network.\n",
+    "\n",
+    "It is a three-layer network, consisting of two convolutional layers and a fully connected layer:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "num_classes = 10\n",
+    "input_shape = (28, 28, 1)\n",
+    "\n",
+    "model = Sequential()\n",
+    "model.add(Conv2D(32, kernel_size=(3, 3),\n",
+    "                 activation='relu',\n",
+    "                 input_shape=input_shape))\n",
+    "model.add(Conv2D(64, (3, 3), activation='relu'))\n",
+    "model.add(MaxPooling2D(pool_size=(2, 2)))\n",
+    "model.add(Dropout(0.25))\n",
+    "model.add(Flatten())\n",
+    "model.add(Dense(128, activation='relu'))\n",
+    "model.add(Dropout(0.5))\n",
+    "model.add(Dense(num_classes, activation='softmax'))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Once a network has been defined, you can compile it so that the model is ready to be trained with data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.compile(loss=keras.losses.categorical_crossentropy,\n",
+    "              optimizer=keras.optimizers.Adadelta(),\n",
+    "              metrics=['accuracy'])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can view a summary of the network using the `summary()` function:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "_________________________________________________________________\n",
+      "Layer (type)                 Output Shape              Param #   \n",
+      "=================================================================\n",
+      "conv2d_1 (Conv2D)            (None, 26, 26, 32)        320       \n",
+      "_________________________________________________________________\n",
+      "conv2d_2 (Conv2D)            (None, 24, 24, 64)        18496     \n",
+      "_________________________________________________________________\n",
+      "max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64)        0         \n",
+      "_________________________________________________________________\n",
+      "dropout_1 (Dropout)          (None, 12, 12, 64)        0         \n",
+      "_________________________________________________________________\n",
+      "flatten_1 (Flatten)          (None, 9216)              0         \n",
+      "_________________________________________________________________\n",
+      "dense_1 (Dense)              (None, 128)               1179776   \n",
+      "_________________________________________________________________\n",
+      "dropout_2 (Dropout)          (None, 128)               0         \n",
+      "_________________________________________________________________\n",
+      "dense_2 (Dense)              (None, 10)                1290      \n",
+      "=================================================================\n",
+      "Total params: 1,199,882\n",
+      "Trainable params: 1,199,882\n",
+      "Non-trainable params: 0\n",
"_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use Augmentor with a DataFrame\n", + "\n", + "Now we will use Augmentor from a DataFrame to train the model. We can create a DataFrame from the download data using the glob command and then parsing pieces of the path\n", + "\n", + "\n", + "To get the data, we can use `wget` (this may not work under Windows):" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists('mnist_png'):\n", + " !wget https://rawgit.com/myleott/mnist_png/master/mnist_png.tar.gz\n", + " !tar -xf mnist_png.tar.gz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After the MNIST data has downloaded, we can instantiate a `Pipeline` object in the `training` directory to add the images to the current pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
pathdata_splitmnist_cat
0mnist_png/training/9/36655.pngtraining9
1mnist_png/training/9/32433.pngtraining9
2mnist_png/training/9/28319.pngtraining9
3mnist_png/training/9/4968.pngtraining9
4mnist_png/training/9/23502.pngtraining9
\n", + "
" + ], + "text/plain": [ + " path data_split mnist_cat\n", + "0 mnist_png/training/9/36655.png training 9\n", + "1 mnist_png/training/9/32433.png training 9\n", + "2 mnist_png/training/9/28319.png training 9\n", + "3 mnist_png/training/9/4968.png training 9\n", + "4 mnist_png/training/9/23502.png training 9" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from glob import glob\n", + "import pandas as pd\n", + "import os\n", + "image_df = pd.DataFrame(dict(path = glob('mnist_png/*/*/*.png')))\n", + "image_df['data_split'] = image_df['path'].map(lambda x: x.split('/')[-3])\n", + "image_df['mnist_cat'] = image_df['path'].map(lambda x: x.split('/')[-2])\n", + "image_df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initialised with 60000 image(s) found.\n", + "Output directory set to output." + ] + } + ], + "source": [ + "p = Augmentor.DataFramePipeline(image_df.query('data_split==\"training\"'), \n", + " image_col = 'path', \n", + " category_col = 'mnist_cat')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add Operations to the Pipeline\n", + "\n", + "Now that a pipeline object `p` has been created, we can add operations to the pipeline. Below we add several simple operations:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "p.flip_top_bottom(probability=0.1)\n", + "p.rotate(probability=0.3, max_left_rotation=5, max_right_rotation=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can view the status of pipeline using the `status()` function, which shows information regarding the number of classes in the pipeline, the number of images, and what operations have been added to the pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Operations: 2\n", + "\t0: Flip (probability=0.1 top_bottom_left_right=TOP_BOTTOM )\n", + "\t1: RotateRange (probability=0.3 max_left_rotation=-5 max_right_rotation=5 )\n", + "Images: 60000\n", + "Classes: 10\n", + "\tClass index: 0 Class label: 0 \n", + "\tClass index: 1 Class label: 1 \n", + "\tClass index: 2 Class label: 2 \n", + "\tClass index: 3 Class label: 3 \n", + "\tClass index: 4 Class label: 4 \n", + "\tClass index: 5 Class label: 5 \n", + "\tClass index: 6 Class label: 6 \n", + "\tClass index: 7 Class label: 7 \n", + "\tClass index: 8 Class label: 8 \n", + "\tClass index: 9 Class label: 9 \n", + "Dimensions: 1\n", + "\tWidth: 28 Height: 28\n", + "Formats: 1\n", + "\t PNG\n", + "\n", + "You can remove operations using the appropriate index and the remove_operation(index) function.\n" + ] + } + ], + "source": [ + "p.status()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a Generator\n", + "\n", + "A generator will create images indefinitely, and we can use this generator as input into the model created above. The generator is created with a user-defined batch size, which we define here in a variable named `batch_size`. This is used later to define number of steps per epoch, so it is best to keep it stored as a variable." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size = 128\n",
+    "g = p.keras_generator(batch_size=batch_size)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The generator can now be used to create augmented data. In Python, generators are invoked using the `next()` function - the Augmentor generators will return images indefinitely, and so `next()` can be called as often as required.\n",
+    "\n",
+    "You can view the output of the generator manually:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "images, labels = next(g)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Images, and their labels, are returned in batches of the size defined above by `batch_size`. Each call to `next()` returns a tuple containing the augmented images and their corresponding labels.\n",
+    "\n",
+    "To see the label of the first image returned by the generator you can use the array's index:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0 0 0 0 0 0 0 1 0 0]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(labels[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Or preview the images using Matplotlib (the image should be a 7, according to the label information above):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJk... (base64 PNG data elided)",
+      "text/plain": [
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow(images[0].reshape(28, 28), cmap=\"Greys\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the Network\n", + "\n", + "We train the network by passing the generator, `g`, to the model's fit function. In Keras, if a generator is used we used the `fit_generator()` function as opposed to the standard `fit()` function. Also, the steps per epoch should roughly equal the total number of images in your dataset divided by the `batch_size`.\n", + "\n", + "Training the network over 5 epochs, we get the following output:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/5\n", + "156/468 [========>.....................] - ETA: 3:55 - loss: 0.7015 - acc: 0.7725" + ] + } + ], + "source": [ + "h = model.fit_generator(g, steps_per_epoch=len(p.augmentor_images)/batch_size, epochs=5, verbose=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "Using Augmentor with Keras means only that you need to create a generator when you are finished creating your pipeline. This has the advantage that no images need to be saved to disk and are augmented on the fly." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:pyqae_base]", + "language": "python", + "name": "conda-env-pyqae_base-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_load.py b/tests/test_load.py index 599829e..f8e5936 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -148,6 +148,53 @@ def test_initialise_with_ten_images(): # not delete itself after closing automatically shutil.rmtree(tmpdir) +def test_dataframe_initialise_with_ten_images(): + import pandas as pd + tmpdir = tempfile.mkdtemp() + tmps = [] + + for i in range(10): + tmps.append(tempfile.NamedTemporaryFile(dir=tmpdir, suffix='.JPEG')) + + bytestream = io.BytesIO() + + im = Image.new('RGB', (800, 800)) + im.save(bytestream, 'JPEG') + + tmps[i].file.write(bytestream.getvalue()) + tmps[i].flush() + temp_df = pd.DataFrame(path = tmps, cat_id = [len(i) for i in tmps]) + p = Augmentor.DataFramePipeline(temp_df, + image_col = 'path', + category_col='cat_id') + assert len(p.augmentor_images) == len(tmps) + + # Check if we can re-read all these images using PIL. + # This will fail for Windows, as you cannot open a file that is already open. 
+ if os.name != "nt": + for i in range(len(tmps)): + im = Image.open(p.augmentor_images[i].image_path) + assert im is not None + + # Check if the paths found during the scan are exactly the paths + # stored by Augmentor after initialisation + for i in range(len(tmps)): + p_paths = [x.image_path for x in p.augmentor_images] + assert tmps[i].name in p_paths + + # Check if all the paths stored by the Pipeline object + # actually exist and are valid paths + for i in range(len(tmps)): + assert os.path.exists(p.augmentor_images[i].image_path) + + # Close all temporary files which will also delete them automatically + for i in range(len(tmps)): + tmps[i].close() + + # Finally remove the directory (and everything in it) as mkdtemp does + # not delete itself after closing automatically + shutil.rmtree(tmpdir) + def test_class_image_scan(): # Some constants