Skip to content

Commit

Permalink
use squeezenet
Browse files Browse the repository at this point in the history
  • Loading branch information
TropComplique committed Aug 4, 2017
1 parent 743be98 commit e2304be
Show file tree
Hide file tree
Showing 11 changed files with 1,004 additions and 397 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
.ipynb_checkpoints
__pycache__
*.hdf5
xception_weights.hdf5
21 changes: 12 additions & 9 deletions get_logits_from_xception.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@
"from image_preprocessing_ver1 import ImageDataGenerator\n",
"# it outputs not only x_batch and y_batch but also image names\n",
"\n",
"from keras.applications.xception import preprocess_input\n",
"from keras.models import Model\n",
"\n",
"from xception import get_xception"
"from xception import Xception, preprocess_input"
]
},
{
Expand Down Expand Up @@ -100,7 +98,7 @@
},
"outputs": [],
"source": [
"model = get_xception()\n",
"model = Xception()\n",
"model.load_weights('xception_weights.hdf5')\n",
"# remove softmax\n",
"model.layers.pop()\n",
Expand All @@ -124,7 +122,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"399it [03:52, 1.71it/s]"
"399it [03:53, 1.70it/s]"
]
}
],
Expand Down Expand Up @@ -158,10 +156,15 @@
"\n",
"0it [00:00, ?it/s]\u001b[A\n",
"1it [00:00, 1.71it/s]\u001b[A\n",
"2it [00:01, 1.73it/s]\u001b[A\n",
"3it [00:01, 1.74it/s]\u001b[A\n",
"4it [00:02, 1.76it/s]\u001b[A\n",
"79it [00:44, 1.78it/s]"
"2it [00:01, 1.71it/s]\u001b[A\n",
"3it [00:01, 1.71it/s]\u001b[A\n",
"4it [00:02, 1.71it/s]\u001b[A\n",
"5it [00:02, 1.71it/s]\u001b[A\n",
"6it [00:03, 1.71it/s]\u001b[A\n",
"7it [00:04, 1.71it/s]\u001b[A\n",
"8it [00:04, 1.71it/s]\u001b[A\n",
"9it [00:05, 1.71it/s]\u001b[A\n",
"79it [00:46, 1.70it/s]"
]
}
],
Expand Down
233 changes: 170 additions & 63 deletions knowledge_distillation.ipynb

Large diffs are not rendered by default.

17 changes: 0 additions & 17 deletions mobilenet.py

This file was deleted.

96 changes: 96 additions & 0 deletions squeezenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import numpy as np
import keras
from keras.models import Model
from keras.layers import Input, Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, GlobalAveragePooling2D, concatenate


mean = np.array([0.485, 0.456, 0.406], dtype='float32')
std = np.array([0.229, 0.224, 0.225], dtype='float32')


def preprocess_input(x):
x /= 255.0
x -= mean
x /= std
return x


# a building block of the SqueezeNet architecture
def fire_module(number, x, squeeze, expand, weight_decay=None, trainable=False):

module_name = 'fire' + number

if trainable and weight_decay is not None:
kernel_regularizer = keras.regularizers.l2(weight_decay)
else:
kernel_regularizer = None

x = Convolution2D(
squeeze, (1, 1),
name=module_name + '/' + 'squeeze',
trainable=trainable,
kernel_regularizer=kernel_regularizer
)(x)
x = Activation('relu')(x)

a = Convolution2D(
expand, (1, 1),
name=module_name + '/' + 'expand1x1',
trainable=trainable,
kernel_regularizer=kernel_regularizer
)(x)
a = Activation('relu')(a)

b = Convolution2D(
expand, (3, 3), padding='same',
name=module_name + '/' + 'expand3x3',
trainable=trainable,
kernel_regularizer=kernel_regularizer
)(x)
b = Activation('relu')(b)

return concatenate([a, b])


def SqueezeNet(weight_decay, image_size=224):

image = Input(shape=(image_size, image_size, 3))

x = Convolution2D(
64, (3, 3), strides=(2, 2), name='conv1',
trainable=False
)(image) # 111, 111, 64

x = Activation('relu')(x)
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(x) # 55, 55, 64

x = fire_module('2', x, squeeze=16, expand=64) # 55, 55, 128
x = fire_module('3', x, squeeze=16, expand=64) # 55, 55, 128
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(x) # 27, 27, 128

x = fire_module('4', x, squeeze=32, expand=128) # 27, 27, 256
x = fire_module('5', x, squeeze=32, expand=128) # 27, 27, 256
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(x) # 13, 13, 256

x = fire_module('6', x, squeeze=48, expand=192) # 13, 13, 384
x = fire_module('7', x, squeeze=48, expand=192) # 13, 13, 384
x = fire_module('8', x, squeeze=64, expand=256) # 13, 13, 512
x = fire_module('9', x, 64, 256, weight_decay, trainable=False) # 13, 13, 512

x = Dropout(0.5)(x)
x = Convolution2D(
256, (1, 1), name='conv10',
kernel_initializer=keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=keras.regularizers.l2(weight_decay)
)(x) # 13, 13, 256

x = Activation('relu')(x)
logits = GlobalAveragePooling2D()(x) # 256
probabilities = Activation('softmax')(logits)

model = Model(image, probabilities)
model.load_weights('squeezenet_weights.hdf5', by_name=True)

return model

Binary file added squeezenet_weights.hdf5
Binary file not shown.
127 changes: 70 additions & 57 deletions train_xception.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
"\n",
"import keras\n",
"from keras import optimizers\n",
"from keras.callbacks import ReduceLROnPlateau, EarlyStopping\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.applications.xception import preprocess_input\n",
"\n",
"from xception import get_xception"
"from xception import Xception, preprocess_input"
]
},
{
Expand All @@ -56,25 +56,38 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Found 25600 images belonging to 256 classes.\n",
"Found 16980 images belonging to 256 classes.\n",
"Found 5120 images belonging to 256 classes.\n"
]
}
],
"source": [
"data_generator = ImageDataGenerator(\n",
" rotation_range=30, \n",
" zoom_range=0.3,\n",
" horizontal_flip=True, \n",
" width_shift_range=0.2,\n",
" height_shift_range=0.2,\n",
" shear_range=0.001,\n",
" channel_shift_range=0.1,\n",
" fill_mode='reflect',\n",
" data_format='channels_last',\n",
" preprocessing_function=preprocess_input\n",
")\n",
"\n",
"data_generator_val = ImageDataGenerator(\n",
" data_format='channels_last',\n",
" preprocessing_function=preprocess_input\n",
")\n",
"\n",
"train_generator = data_generator.flow_from_directory(\n",
" data_dir + 'train', \n",
" data_dir + 'train_no_resizing', \n",
" target_size=(299, 299),\n",
" batch_size=64\n",
")\n",
"\n",
"val_generator = data_generator.flow_from_directory(\n",
" data_dir + 'val', \n",
"val_generator = data_generator_val.flow_from_directory(\n",
" data_dir + 'val', shuffle=False,\n",
" target_size=(299, 299),\n",
" batch_size=64\n",
")"
Expand All @@ -95,10 +108,7 @@
},
"outputs": [],
"source": [
"model = get_xception()\n",
"\n",
"for layer in model.layers[:-7]:\n",
" layer.trainable = False"
"model = Xception(weight_decay=1e-5)"
]
},
{
Expand All @@ -117,7 +127,7 @@
"outputs": [],
"source": [
"model.compile(\n",
" optimizer=optimizers.Adam(lr=1e-4), \n",
" optimizer=optimizers.SGD(lr=1e-2, momentum=0.9, nesterov=True), \n",
" loss='categorical_crossentropy', metrics=['accuracy', 'top_k_categorical_accuracy']\n",
")"
]
Expand All @@ -133,32 +143,52 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"400/400 [==============================] - 205s - loss: 3.7803 - acc: 0.4745 - top_k_categorical_accuracy: 0.6484 - val_loss: 1.9980 - val_acc: 0.7826 - val_top_k_categorical_accuracy: 0.9297\n",
"Epoch 2/10\n",
"400/400 [==============================] - 204s - loss: 1.5028 - acc: 0.7947 - top_k_categorical_accuracy: 0.9386 - val_loss: 1.0635 - val_acc: 0.8206 - val_top_k_categorical_accuracy: 0.9473\n",
"Epoch 3/10\n",
"400/400 [==============================] - 204s - loss: 0.8996 - acc: 0.8484 - top_k_categorical_accuracy: 0.9615 - val_loss: 0.8329 - val_acc: 0.8320 - val_top_k_categorical_accuracy: 0.9535\n",
"Epoch 4/10\n",
"400/400 [==============================] - 204s - loss: 0.6544 - acc: 0.8775 - top_k_categorical_accuracy: 0.9739 - val_loss: 0.7266 - val_acc: 0.8402 - val_top_k_categorical_accuracy: 0.9528\n",
"Epoch 5/10\n",
"400/400 [==============================] - 204s - loss: 0.5098 - acc: 0.9024 - top_k_categorical_accuracy: 0.9819 - val_loss: 0.6648 - val_acc: 0.8389 - val_top_k_categorical_accuracy: 0.9567\n",
"Epoch 6/10\n",
"400/400 [==============================] - 204s - loss: 0.4047 - acc: 0.9221 - top_k_categorical_accuracy: 0.9880 - val_loss: 0.6352 - val_acc: 0.8483 - val_top_k_categorical_accuracy: 0.9574\n",
"Epoch 7/10\n",
"400/400 [==============================] - 204s - loss: 0.3253 - acc: 0.9391 - top_k_categorical_accuracy: 0.9919 - val_loss: 0.6081 - val_acc: 0.8503 - val_top_k_categorical_accuracy: 0.9580\n",
"Epoch 8/10\n",
"400/400 [==============================] - 204s - loss: 0.2639 - acc: 0.9524 - top_k_categorical_accuracy: 0.9954 - val_loss: 0.5949 - val_acc: 0.8519 - val_top_k_categorical_accuracy: 0.9590\n",
"Epoch 9/10\n",
"400/400 [==============================] - 204s - loss: 0.2161 - acc: 0.9639 - top_k_categorical_accuracy: 0.9965 - val_loss: 0.5808 - val_acc: 0.8555 - val_top_k_categorical_accuracy: 0.9567\n",
"Epoch 10/10\n",
"400/400 [==============================] - 204s - loss: 0.1775 - acc: 0.9703 - top_k_categorical_accuracy: 0.9979 - val_loss: 0.5569 - val_acc: 0.8610 - val_top_k_categorical_accuracy: 0.9609\n"
"Epoch 1/30\n",
"266/266 [==============================] - 163s - loss: 3.9762 - acc: 0.3249 - top_k_categorical_accuracy: 0.4612 - val_loss: 3.2910 - val_acc: 0.4666 - val_top_k_categorical_accuracy: 0.7197\n",
"Epoch 2/30\n",
"266/266 [==============================] - 161s - loss: 2.2335 - acc: 0.6227 - top_k_categorical_accuracy: 0.8145 - val_loss: 1.8696 - val_acc: 0.6570 - val_top_k_categorical_accuracy: 0.8789\n",
"Epoch 3/30\n",
"266/266 [==============================] - 161s - loss: 1.5226 - acc: 0.7112 - top_k_categorical_accuracy: 0.8911 - val_loss: 1.3368 - val_acc: 0.7346 - val_top_k_categorical_accuracy: 0.9189\n",
"Epoch 4/30\n",
"266/266 [==============================] - 161s - loss: 1.2231 - acc: 0.7527 - top_k_categorical_accuracy: 0.9124 - val_loss: 1.0939 - val_acc: 0.7707 - val_top_k_categorical_accuracy: 0.9289\n",
"Epoch 5/30\n",
"266/266 [==============================] - 161s - loss: 1.0554 - acc: 0.7792 - top_k_categorical_accuracy: 0.9242 - val_loss: 0.9827 - val_acc: 0.7822 - val_top_k_categorical_accuracy: 0.9361\n",
"Epoch 6/30\n",
"266/266 [==============================] - 161s - loss: 0.9525 - acc: 0.7920 - top_k_categorical_accuracy: 0.9331 - val_loss: 0.9031 - val_acc: 0.7932 - val_top_k_categorical_accuracy: 0.9414\n",
"Epoch 7/30\n",
"266/266 [==============================] - 161s - loss: 0.8914 - acc: 0.7995 - top_k_categorical_accuracy: 0.9387 - val_loss: 0.8727 - val_acc: 0.8002 - val_top_k_categorical_accuracy: 0.9402\n",
"Epoch 8/30\n",
"266/266 [==============================] - 161s - loss: 0.8319 - acc: 0.8118 - top_k_categorical_accuracy: 0.9432 - val_loss: 0.8268 - val_acc: 0.8053 - val_top_k_categorical_accuracy: 0.9447\n",
"Epoch 9/30\n",
"266/266 [==============================] - 161s - loss: 0.7794 - acc: 0.8232 - top_k_categorical_accuracy: 0.9491 - val_loss: 0.8082 - val_acc: 0.8102 - val_top_k_categorical_accuracy: 0.9441\n",
"Epoch 10/30\n",
"266/266 [==============================] - 161s - loss: 0.7413 - acc: 0.8303 - top_k_categorical_accuracy: 0.9529 - val_loss: 0.7862 - val_acc: 0.8133 - val_top_k_categorical_accuracy: 0.9477\n",
"Epoch 11/30\n",
"266/266 [==============================] - 161s - loss: 0.7166 - acc: 0.8355 - top_k_categorical_accuracy: 0.9569 - val_loss: 0.7667 - val_acc: 0.8170 - val_top_k_categorical_accuracy: 0.9488\n",
"Epoch 12/30\n",
"266/266 [==============================] - 161s - loss: 0.6792 - acc: 0.8450 - top_k_categorical_accuracy: 0.9599 - val_loss: 0.7592 - val_acc: 0.8191 - val_top_k_categorical_accuracy: 0.9492\n",
"Epoch 13/30\n",
"266/266 [==============================] - 161s - loss: 0.6620 - acc: 0.8479 - top_k_categorical_accuracy: 0.9629 - val_loss: 0.7417 - val_acc: 0.8229 - val_top_k_categorical_accuracy: 0.9500\n",
"Epoch 14/30\n",
"266/266 [==============================] - 161s - loss: 0.6207 - acc: 0.8553 - top_k_categorical_accuracy: 0.9692 - val_loss: 0.7457 - val_acc: 0.8223 - val_top_k_categorical_accuracy: 0.9492\n",
"Epoch 15/30\n",
"266/266 [==============================] - 161s - loss: 0.6173 - acc: 0.8518 - top_k_categorical_accuracy: 0.9659 - val_loss: 0.7210 - val_acc: 0.8271 - val_top_k_categorical_accuracy: 0.9520\n",
"Epoch 16/30\n",
"266/266 [==============================] - 161s - loss: 0.5821 - acc: 0.8644 - top_k_categorical_accuracy: 0.9711 - val_loss: 0.7388 - val_acc: 0.8229 - val_top_k_categorical_accuracy: 0.9492\n",
"Epoch 17/30\n",
"266/266 [==============================] - 161s - loss: 0.5675 - acc: 0.8666 - top_k_categorical_accuracy: 0.9726 - val_loss: 0.7175 - val_acc: 0.8303 - val_top_k_categorical_accuracy: 0.9510\n",
"Epoch 18/30\n",
"266/266 [==============================] - 161s - loss: 0.5485 - acc: 0.8706 - top_k_categorical_accuracy: 0.9749 - val_loss: 0.7128 - val_acc: 0.8307 - val_top_k_categorical_accuracy: 0.9525\n",
"Epoch 19/30\n",
"266/266 [==============================] - 161s - loss: 0.5297 - acc: 0.8741 - top_k_categorical_accuracy: 0.9764 - val_loss: 0.7083 - val_acc: 0.8299 - val_top_k_categorical_accuracy: 0.9527\n",
"Epoch 20/30\n",
"266/266 [==============================] - 161s - loss: 0.5163 - acc: 0.8800 - top_k_categorical_accuracy: 0.9772 - val_loss: 0.7003 - val_acc: 0.8328 - val_top_k_categorical_accuracy: 0.9535\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f5ce44c3cc0>"
"<keras.callbacks.History at 0x7fcc2655a160>"
]
},
"execution_count": 7,
Expand All @@ -169,8 +199,12 @@
"source": [
"model.fit_generator(\n",
" train_generator, \n",
" steps_per_epoch=400, epochs=10, verbose=1,\n",
" validation_data=val_generator, validation_steps=48\n",
" steps_per_epoch=266, epochs=30, verbose=1,\n",
" callbacks=[\n",
" ReduceLROnPlateau(monitor='val_acc', factor=0.1, patience=2, epsilon=0.007),\n",
" EarlyStopping(monitor='val_acc', patience=4, min_delta=0.01)\n",
" ],\n",
" validation_data=val_generator, validation_steps=80, workers=4\n",
")"
]
},
Expand All @@ -181,27 +215,6 @@
"# Results"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 5120 images belonging to 256 classes.\n"
]
}
],
"source": [
"val_generator_no_shuffle = data_generator.flow_from_directory(\n",
" data_dir + 'val', \n",
" target_size=(299, 299),\n",
" batch_size=64, shuffle=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
Expand All @@ -210,7 +223,7 @@
{
"data": {
"text/plain": [
"[0.57773708868771789, 0.85429687499999996, 0.95703125]"
"[0.70807909807190295, 0.83027343750000004, 0.953125]"
]
},
"execution_count": 9,
Expand All @@ -219,7 +232,7 @@
}
],
"source": [
"model.evaluate_generator(val_generator_no_shuffle, 80)"
"model.evaluate_generator(val_generator, 80)"
]
},
{
Expand Down
Loading

0 comments on commit e2304be

Please sign in to comment.