Skip to content

Commit

Permalink
Added multi GPU support (basic version)
Browse files Browse the repository at this point in the history
  • Loading branch information
TatjanaUtz committed Apr 9, 2019
1 parent e008cfd commit 9ddedb6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
4 changes: 2 additions & 2 deletions database_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _check_if_database_exists(self):
This method checks if the database defined by self.db_dir and self.db_name exists. If not,
an exception is raised.
"""
db_path = self.db_dir + "\\" + self.db_name # create full path to database
db_path = os.path.join(self.db_dir, self.db_name) # create full path to database

# check if database exists
if not os.path.exists(db_path): # database doesn't exists: raise exception
Expand Down Expand Up @@ -222,7 +222,7 @@ def _open_db(self):
This methods opens the database defined by self.db_dir and self.db_name by creating a
database connection and a cursor.
"""
db_path = self.db_dir + "\\" + self.db_name # create full path to the database
db_path = os.path.join(self.db_dir, self.db_name) # create full path to the database

# create database connection and a cursor
self.db_connection = sqlite3.connect(db_path)
Expand Down
10 changes: 7 additions & 3 deletions logging_config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
"""Configurations for logging."""

import logging
import os



def init_logging():
def init_logging(db_dir, db_name):
"""Initializes logging.
Configures logging. Error messages are logged to the 'error.log' file. Info messages are logged
to the console. The results are saved in a 'result_' log file.
Args:
db_dir -- directory of the database, used to create the file for results
db_name -- name of the database, used to create file name for results
"""
# create logger for traditional-SA project
logger = logging.getLogger('RNN-SA')
logger.setLevel(logging.INFO)

# create file handler which logs error messages
log_file_handler = logging.FileHandler('error.log', mode='w+')
log_file_handler = logging.FileHandler(os.path.join(db_dir, 'error.log'), mode='w+')
log_file_handler.setLevel(logging.ERROR)

# create console handler with a lower log level (e.g debug or info)
Expand Down
36 changes: 31 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from random import shuffle # for shuffle of the task-sets

import keras
import tensorflow as tf
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import sklearn
Expand All @@ -18,6 +21,7 @@
import logging_config
import ml_models
import params
import os

# default indices of all task attributes (column indices of 'Task')
DEFAULT_FEATURES = ['Task_ID', 'Priority', 'Deadline', 'Quota', 'CAPS', 'PKG', 'Arg', 'CORES',
Expand All @@ -39,10 +43,10 @@
def main():
"""Main function of project 'RNN-SA'."""
# determine database directory and name
db_dir, db_name = "..\\Datenbanken\\", "panda_v2.db"
db_dir, db_name = os.getcwd(), "panda_v3.db"

# create and initialize logger
logger = logging_config.init_logging()
logger = logging_config.init_logging(db_dir, db_name)

# load the data
logger.info("Loading and pre-processing data from the database...")
Expand All @@ -58,16 +62,38 @@ def main():
logger.info("Doing hyperparameter exploration...")
start_time = time.time()
#h = hyperparameter_exploration(data=data, name='LSTM', num='0')
out, model = ml_models.LSTM_model(data['train_X'], data['train_y'], data['val_X'],
data['val_y'], params.hparams)
# out, model = ml_models.LSTM_model(data['train_X'], data['train_y'], data['val_X'],
# data['val_y'], params.hparams)
end_time = time.time()
logger.info("Finished hyperparameter exploration!")
logger.info("Best result: ")
logger.info("Time elapsed: %f s \n", end_time - start_time)

# multi GPU support
model = ml_models._build_LSTM_model(params.hparams, params.config)
try:
parallel_model = keras.utils.multi_gpu_model(model, cpu_relocation=True)
print("Training using multiple GPUs...")
except ValueError:
parallel_model = model
print("Training using single GPU or CPU...")
parallel_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
out = parallel_model.fit(
x=data['train_X'],
y=data['train_y'],
batch_size=params['batch_size'],
epochs=params['num_epochs'],
verbose=params.config['verbose_training'],
callbacks=ml_models._init_callbacks(params.hparams, params.config),
validation_data=[data['val_X'], data['val_y']],
shuffle=True,
)



# evaluate
loss, acc = model.evaluate(data['test_X'], data['test_y'], batch_size=params.hparams['batch_size'])
print("Loss: %f --- Accuracy: %f", loss, acc)
logger.info("Loss: %f --- Accuracy: %f", loss, acc)


def hyperparameter_exploration(data, name, num):
Expand Down

0 comments on commit 9ddedb6

Please sign in to comment.