Skip to content

Commit

Permalink
fintune inception
Browse files Browse the repository at this point in the history
  • Loading branch information
Sai Kiran Burle committed Dec 10, 2017
1 parent 627b5b8 commit 4c9ccef
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 21 deletions.
2 changes: 1 addition & 1 deletion binary_image_classifier/simple_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array

TRAIN_PATH = '/Users/sai/dev/catsdogs/data2/train/'
TRAIN_PATH = '/Users/sai/dev/datasets/catsdogs-kaggle/data2/train/'

# Constants
NUM_CHANNELS = 3
Expand Down
86 changes: 66 additions & 20 deletions finetune/inception.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,79 @@
from keras.applications.inception_v3 import InceptionV3, preprocess_input, decode_predictions
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

NUM_CLASSES = 2
TRAIN_PATH = '/Users/sai/dev/datasets/catsdogs-kaggle/data2/train/'

# Constants
NUM_CHANNELS = 3
IMG_X = 150
IMG_Y = 150

def batch_generator():
pass
BATCH_SIZE = 16
TOTAL_NUM_IMAGES = 25000

# base pre-trained model
base_model = InceptionV3(include_top=False, weights='imagenet')

# Global
x = base_model.output
x = GlobalAveragePooling2D()(x)
def get_train_data_augmenter():
# real time image augmentation
augmenter = ImageDataGenerator(
# rotation_range=40,
# width_shift_range=0.2,
# height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest',
# rescale=1./255,
preprocessing_function=preprocess_input
)
return augmenter

# Fully connected layer
x = Dense(units=1024, activation='relu')(x)
# Logistic softmax layer
predictions = Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)
def get_train_data_generator(augmenter):
train_generator = augmenter.flow_from_directory(
TRAIN_PATH,
target_size=(IMG_X, IMG_Y),
batch_size=BATCH_SIZE,
class_mode='binary'
)
return train_generator

# Train only the top layers
for layer in base_model.layers:
layer.trainable = False

# Compile the model
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
def get_model():
# base pre-trained model
base_model = InceptionV3(include_top=False, weights='imagenet')

# Train the model
model.fit_generator()
# Global
x = base_model.output
x = GlobalAveragePooling2D()(x)

# Fully connected layer
x = Dense(units=1024, activation='relu')(x)
# Logistic softmax layer
predictions = Dense(1, activation='sigmoid')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Train only the top layers
for layer in base_model.layers:
layer.trainable = False

# Compile the model
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

return model


def train_model(model, data_gen):
model.fit_generator(
data_gen,
steps_per_epoch=TOTAL_NUM_IMAGES // BATCH_SIZE,
epochs=50
)

if __name__ == "__main__":
augmenter = get_train_data_augmenter()
model = get_model()
train_data_gen = get_train_data_generator(augmenter)
train_model(model, train_data_gen)

0 comments on commit 4c9ccef

Please sign in to comment.