Skip to content

Commit

Permalink
Changed image size from 100x100 to 64x64.
Browse files Browse the repository at this point in the history
Implemented save model graph method (needs graphviz packaged!).
  • Loading branch information
SvenMuc committed May 23, 2017
1 parent d026f08 commit a5262dd
Showing 1 changed file with 65 additions and 7 deletions.
72 changes: 65 additions & 7 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
VALIDATION_SET_SIZE = 0.2 # proportion of full dataset used for the test set
BATCH_SIZE = 256 # default training batch size (number of images per batch)
ROI = [20, 60, 320 - 20, 160 - 22] # Region if interest
IMAGE_WIDTH = 100 # Model input image width
IMAGE_HEIGHT = 100 # Model input image height
IMAGE_WIDTH = 64 # Model input image width
IMAGE_HEIGHT = 64 # Model input image height

# augmentation
STEERING_ANGLE_CORRECTION = 6.25 # Steering angle correction for left and right images in degree
Expand Down Expand Up @@ -195,6 +195,62 @@ def train_model(model, dataset_csv_filename, trained_model=None, nb_epochs=7, nb
network.safe_model(model + '_model.h5')


def save_model_graph(model, filename):
""" Saves the model graph to a PNG file.
:param model: Supported models are
- LeNet5
- NvidiaFull
- NvidiaLight (lightweight NVIDIA architecture)
- VGG16
:param filename: Filename of PNG file.
"""

supported_models = ['LeNet5', 'NvidiaFull', 'NvidiaLight', 'VGG16']

# check for valid model
matching = [s for s in supported_models if model in s]

if len(matching) == 0:
print('Not support model. Check help for supported models.')
exit(-1)

if model == 'LeNet5':
# setup LeNet-5 model architecture
network = LeNet(input_depth=3, input_height=IMAGE_HEIGHT, input_width=IMAGE_WIDTH,
regression=True, nb_classes=1,
roi=ROI,
steering_angle_correction=STEERING_ANGLE_CORRECTION,
weights_path=None)
elif model == 'NvidiaFull':
# setup Full NVIDIA CNN model architecture
network = NvidiaFull(input_depth=3, input_height=IMAGE_HEIGHT, input_width=IMAGE_WIDTH,
regression=True, nb_classes=1,
roi=ROI,
steering_angle_correction=STEERING_ANGLE_CORRECTION,
weights_path=None)
elif model == 'NvidiaLight':
# setup lightweight NVIDIA CNN model architecture
network = NvidiaLight(input_depth=3, input_height=IMAGE_HEIGHT, input_width=IMAGE_WIDTH,
regression=True, nb_classes=1,
roi=ROI,
steering_angle_correction=STEERING_ANGLE_CORRECTION,
weights_path=None)
elif model == 'VGG16':
# setup VGG-16 model architecture
network = VGG(input_depth=3, input_height=IMAGE_HEIGHT, input_width=IMAGE_WIDTH,
regression=True, nb_classes=1,
roi=ROI,
steering_angle_correction=STEERING_ANGLE_CORRECTION,
weights_path=None)
else:
print('Not support model. Check help for supported models.')
exit(-1)

print('Safe model graph...', end='', flush=True)
network.save_model_graph(filename)
print('done')

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Training and Evaluation')

Expand Down Expand Up @@ -244,14 +300,15 @@ def train_model(model, dataset_csv_filename, trained_model=None, nb_epochs=7, nb
'-th', '--show_training_history',
help='Plots the training history (train loss and validation loss).',
dest='history_filename',
metavar="FILE"
metavar='FILE'
)

parser.add_argument(
'-smg', '--safe-model-graph',
help='Saves the model graph to a PNG file.',
help='Saves the model graph to a PNG file (1=model name, 2=filename).',
dest='safe_model_graph',
metavar="FILE"
nargs=2,
metavar='ARG'
)

args = parser.parse_args()
Expand Down Expand Up @@ -306,5 +363,6 @@ def train_model(model, dataset_csv_filename, trained_model=None, nb_epochs=7, nb
print('History object file \"{:s}\" not found.'.format(args.history_filename))

elif args.safe_model_graph:
# TODO: safes the model graph to a PNG file
print('Safe model graph')
# safes the model graph to a PNG file
save_model_graph(args.safe_model_graph[0], args.safe_model_graph[1])

0 comments on commit a5262dd

Please sign in to comment.