-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
executable file
·653 lines (460 loc) · 18.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
#!/usr/bin/env python3
# http://flask.pocoo.org/
from flask import Flask
# http://flask.pocoo.org/docs/1.0/api/#flask.request
from flask import request
# http://flask.pocoo.org/docs/1.0/tutorial/templates/
from flask import render_template
from flask import redirect, Response
from werkzeug.utils import secure_filename
from keras.callbacks import LambdaCallback
# to save files
import os
import logging
import time
from datetime import datetime
# for json responses
import json
# for multithreading
from threading import Thread
# import network setup from parent directory
# so that we can start training a network
import sys
parentPath = os.path.abspath("..")
# add parent path to sys.path for import
if parentPath not in sys.path:
sys.path.insert(0, parentPath)
# import this module from the parent directory
from network_setup import externalSetup
# import midi parser
from midi_parser.parse_midi import MIDI_Converter
# import server settings (global variables)
from settings import *
# for error handling
# see: https://docs.python.org/dev/library/traceback.html
# we can also use use: https://docs.python.org/2/library/repr.html
import traceback
# for listing the midi files in the result folder
import glob
# use flask-socketio for socket connection
from flask_socketio import SocketIO, send, emit
app = Flask(__name__)
# setting the key allows us to use sessions in a flask application
# NOTE(review): the key is hard-coded - presumably it should come from the
# settings or the environment for anything but local use; confirm.
app.config['SECRET_KEY'] = 'secret!'
# start new socketio instance for the application
socketio = SocketIO(app)
# TRAINING_STATUS is the dict that is sent to the clients, it contains:
# - status (converting/training/failure)
# - finished
# - error (if errors occurred)
# - epoch
# - epochs (total amount)
# - songs, start, end, remaining, loss, result (set while/after training)
TRAINING_STATUS = {}
# thread that runs train_network(), None while no training is running
TRAINING_THREAD = None
# holds server start time in string format
TIMESTAMP_SERVER_START = None
# server logger
SVR_LOGGER = None
# time.time() value taken when the current epoch started
EPOCH_START = None
# durations of the finished epochs (seconds)
EPOCH_DURATIONS = []
# one {'loss': ..., 'epoch': ...} entry per finished epoch
LOSS_ALL = []
def main():
    '''
    Sets up logging, configures the Flask application and starts the
    socketio server.

    The logger and the server start timestamp are published through the
    module-level SVR_LOGGER and TIMESTAMP_SERVER_START variables.
    '''
    global SVR_LOGGER
    # without this declaration the assignment below would only create a
    # local variable and the module-level timestamp would stay None
    global TIMESTAMP_SERVER_START
    # get current time for log-files and more in string format
    TIMESTAMP_SERVER_START = getTimestampNow()

    ###### logging configuration ######
    # create formatter
    logFormat = '%(asctime)s - [%(levelname)s]: %(message)s'
    logDateFormat = '%m/%d/%Y %I:%M:%S %p'
    formatter = logging.Formatter(fmt=logFormat, datefmt=logDateFormat)

    # create console handler (to log to console as well)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG if DEBUG else logging.INFO)
    ch.setFormatter(formatter)

    # validate log path and filename
    logPath = LOG_FOLDER
    if logPath is None or logPath == "":
        logPath = "./"
    if not logPath.endswith("/"):
        logPath += "/"
    if not os.path.exists(logPath):
        print('[SERVER] Missing path "{}" - creating it.'.format(logPath))
        os.makedirs(logPath)

    logFileName = LOG_FILENAME
    if logFileName is None or logFileName == "":
        logFileName = "log_{}.log"
    # place timestamp
    logFileName = logFileName.format(TIMESTAMP_SERVER_START)

    # configure logging, level=DEBUG => log everything to the file
    logging.basicConfig(
        filename=logPath + logFileName,
        level=logging.DEBUG,
        format=logFormat,
        datefmt=logDateFormat
    )

    # get the logger
    SVR_LOGGER = logging.getLogger('musicnet-webservicelogger')
    SVR_LOGGER.addHandler(ch)
    SVR_LOGGER.debug('Logger started.')
    ###### logging configuration ######

    # configure application
    app.debug = DEBUG
    app.jinja_env.trim_blocks = True  # disable jinja2 empty lines
    app.jinja_env.lstrip_blocks = True
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

    # the server is run through socketio instead of app.run()
    socketio.run(app, host='0.0.0.0', port=PORT)
# makes all the setting variables available to the templates
# see http://flask.pocoo.org/docs/1.0/templating/#context-processors
@app.context_processor
def inject_settings():
    ''' Context processor that hands a copy of the SETTINGS dict to the template engine. '''
    template_vars = dict(SETTINGS)
    return template_vars
@app.route("/", methods=["GET", "POST"])
def submit():
    '''
    Main interface: GET renders the upload form, POST validates the
    uploaded files and settings and starts the training thread.
    '''
    global TRAINING_THREAD

    method = request.method
    if method == "GET":
        return render_template(
            'index.html',
            title=TITLE,
            accept=ACCEPTED_FILE_EXTENSIONS
        )
    if method != "POST":
        return None

    SVR_LOGGER.debug("Got files: {}".format(request.files))
    SVR_LOGGER.debug("Got settings: {}".format(request.form))

    if len(request.files) == 0:
        # TODO: redirect to main page and insert error
        return "No files!"

    # saves the accepted uploads and maps filename -> path
    filePaths = validateFiles(request.files.getlist("file"))
    SVR_LOGGER.info("Validated files: {}".format(filePaths))

    uploaded = len(filePaths)
    SVR_LOGGER.info("Files uploaded: {}\n{}".format(uploaded, filePaths))
    if uploaded == 0:
        # TODO: redirect to main page and insert error
        return "No files!"

    # validateSettings returns a dict on success, an error text otherwise
    settings = validateSettings(settings_in=request.form)
    if not isinstance(settings, dict):
        return "Settings are invalid! ({})".format(settings)

    # only the paths are needed from here on (the names are the keys)
    settings['filepaths'] = list(filePaths.values())
    SVR_LOGGER.info("Using settings: {}".format(settings))

    # only one training may run at a time
    if TRAINING_THREAD is not None:
        SVR_LOGGER.debug("User tried to start a new training process but am still training...")
        # TODO: nice error website or navigate to training
        return "There is still a training running..\nPlease wait before starting a new one..."

    # train the network in a separate daemon thread
    worker = Thread(target=train_network, kwargs={'settings': settings})
    worker.daemon = True
    TRAINING_THREAD = worker
    TRAINING_THREAD.start()

    # redirect to training page
    return redirect("./training", code=303)
@app.route("/training", methods=["GET"])
def training():
    ''' Renders the training page. '''
    page_title = TITLE + " - Training"
    return render_template('training.html', title=page_title)
@app.route("/training/status", methods=["GET"])
def training_state():
    ''' Delivers the current TRAINING_STATUS as a JSON response. '''
    status_response = jsonResponse(TRAINING_STATUS)
    return status_response
@app.route("/results")
def results():
    ''' Delivers the midi result files under the key 'results' as a JSON response. '''
    payload = {'results': getResultFilepaths()}
    return jsonResponse(payload)
@socketio.on("connect")
def client_connected():
    ''' Logs the new client, then emits the training status and the result list. '''
    # for remote_addr see http://flask.pocoo.org/docs/0.12/api/#flask.Request
    # see also https://tedboy.github.io/flask/generated/generated/flask.Request.html#attributes
    client_address = request.remote_addr
    SVR_LOGGER.info("Client connected! ({})".format(client_address))

    # broadcast the status to all clients
    socketio.emit("status", TRAINING_STATUS, broadcast=True, include_self=True)

    # send current result list
    broadcastResultFiles()
@socketio.on("disconnect")
def client_connected():
    ''' Logs the address of a client that disconnected. '''
    # NOTE(review): this handler reuses the function name of the "connect"
    # handler and rebinds it at module level - presumably a copy/paste
    # leftover; consider renaming it to client_disconnected.
    # for remote_addr see http://flask.pocoo.org/docs/0.12/api/#flask.Request
    client_address = request.remote_addr
    SVR_LOGGER.info("Client disconnected! ({})".format(client_address))
@socketio.on("status")
def socket_getTrainingStatus():
    ''' Answers a "status" socket message by emitting the current TRAINING_STATUS. '''
    current_status = TRAINING_STATUS
    emit("status", current_status, json=True)
def broadcastResultFiles():
    ''' Broadcasts the current list of result files as a "results" event. '''
    payload = {'results': getResultFilepaths()}
    socketio.emit("results", payload, json=True, broadcast=True)
def getTimestampNow():
    ''' Returns the current local time formatted as YYYY-MM-DD_HH-MM-SS. '''
    timestamp_format = '%Y-%m-%d_%H-%M-%S'
    now = datetime.now()
    return now.strftime(timestamp_format)
def jsonResponse(data):
    ''' Serializes data with json.dumps and wraps it in a 200 application/json Response. '''
    body = json.dumps(data)
    return Response(response=body, status=200, mimetype='application/json')
def train_network(settings):
    '''
    Function that will run in a separate thread to train the network.

    Converts the uploaded MIDI files to JSON, trains the network through
    externalSetup() and publishes progress and result in TRAINING_STATUS.
    Every failure is reported once through train_network_error(), which
    also resets TRAINING_THREAD, and ends the function.

    settings: validated settings dict, the keys 'filepaths' and 'epochs'
              are read here, the whole dict is handed to externalSetup().
    '''
    global TRAINING_THREAD
    global TRAINING_STATUS

    # try to convert midi files to json
    try:
        TRAINING_STATUS = {}  # clear everything
        filePaths_midi = settings['filepaths']
        TRAINING_STATUS['status'] = "converting"
        filePath_json = convertMidiFiles(filePaths_midi)
    except Exception:
        errmsg = "Failed to convert MIDI to JSON! ({})".format(traceback.format_exc())
        train_network_error(errmsg, SVR_LOGGER)
        return

    try:
        SVR_LOGGER.info("Training network...")
        SVR_LOGGER.info("- Settings: {}".format(settings))
        SVR_LOGGER.info("- JSON-Path: {}".format(filePath_json))

        # set initial status
        trainingInit(settings)

        # additional callback for the status updates at epoch begin/end
        epoch_update_callback = LambdaCallback(
            on_epoch_begin=lambda epoch, logs: updateEpochBegin(epoch, logs),
            on_epoch_end=lambda epoch, logs: updateEpochEnd(epoch, logs))
        callbacks = [epoch_update_callback]

        # check that folders exist is done in the setup
        # this will start training the network
        resultMidiPath = externalSetup(
            logger=SVR_LOGGER,
            jsonFilesPath=filePath_json,
            weightsOutPath=WEIGHT_FOLDER.format(getTimestampNow()),
            midiOutPath=RESULT_FOLDER.format(getTimestampNow()),
            settings=settings,
            callbacks=callbacks
        )
    except Exception:
        errmsg = "An unexpected error occured! ({})".format(traceback.format_exc())
        train_network_error(errmsg, SVR_LOGGER)
        return

    # check if we got results; stop here so that the failure is neither
    # marked as finished nor reported a second time further down
    if resultMidiPath is None or len(resultMidiPath) == 0:
        # TODO: handle the error by showing error page
        train_network_error("Network delivered no result!", SVR_LOGGER)
        return

    # update status
    # for pop() see https://docs.python.org/3/library/stdtypes.html#dict.pop
    SVR_LOGGER.info("Training finished!")
    TRAINING_STATUS['finished'] = True
    TRAINING_STATUS['epoch'] = settings['epochs']
    TRAINING_STATUS['end'] = getTimestampNow()
    TRAINING_STATUS.pop('error', None)  # None to prevent KeyError if key not given

    # add path to result, dropping the first character
    # (assumes the path starts with "." - TODO confirm against externalSetup)
    TRAINING_STATUS['result'] = resultMidiPath[1:]

    # tell that the thread is done
    TRAINING_THREAD = None

    # send changed training status to clients
    trainingStatusChanged()

    # broadcast list of results to the clients
    broadcastResultFiles()
def trainingInit(settings):
    '''
    Will clean up and initialize all used variables for a new training run
    and send the initial status to the clients.

    settings: validated settings dict, 'epochs' and 'filepaths' are read.
    '''
    global TRAINING_STATUS
    # EPOCH_START has to be declared global as well, otherwise the reset
    # below only creates a local variable and the old value survives
    global EPOCH_START
    global EPOCH_DURATIONS
    global LOSS_ALL

    TRAINING_STATUS['status'] = "training"
    TRAINING_STATUS['finished'] = False
    TRAINING_STATUS['epoch'] = 1
    TRAINING_STATUS['epochs'] = settings['epochs']
    TRAINING_STATUS['songs'] = len(settings['filepaths'])
    TRAINING_STATUS['start'] = getTimestampNow()

    # forget the measurements of a previous run
    EPOCH_START = None
    EPOCH_DURATIONS = []
    LOSS_ALL = []

    trainingStatusChanged()
def train_network_error(errmsg, logger=None):
    '''
    Reports errmsg (logger if given, print otherwise), replaces the status
    with a failure status, sets the thread to be None and notifies the clients.
    '''
    global TRAINING_STATUS
    global TRAINING_THREAD

    if logger is None:
        print(errmsg)
    else:
        logger.error(errmsg)

    # start from a fresh dict so that nothing of the old status survives
    TRAINING_STATUS = {'status': "failure", 'error': errmsg}
    TRAINING_THREAD = None

    trainingStatusChanged()
def updateEpochBegin(epoch, logs):
    ''' Epoch-begin callback: stores the epoch number and start time, then notifies the clients. '''
    global TRAINING_STATUS
    global EPOCH_START

    SVR_LOGGER.info("[Epoch-Begin]: {}".format(epoch))

    # +1 because the status holds the epoch one higher than the callback value
    displayed_epoch = int(epoch) + 1
    TRAINING_STATUS['epoch'] = displayed_epoch

    # remembered for the duration measurement at the end of the epoch
    EPOCH_START = time.time()

    trainingStatusChanged()
def updateEpochEnd(epoch, logs):
    '''
    Epoch-end callback: updates the estimated remaining time and the loss
    history in TRAINING_STATUS. Nothing is emitted from here.
    '''
    global TRAINING_STATUS
    global EPOCH_DURATIONS
    global LOSS_ALL

    # estimate the remaining time from the average epoch duration so far
    elapsed = time.time() - EPOCH_START
    EPOCH_DURATIONS.append(elapsed)
    average_duration = sum(EPOCH_DURATIONS) / len(EPOCH_DURATIONS)
    remaining_epochs = TRAINING_STATUS['epochs'] - (epoch + 1)
    time_left = round(average_duration * remaining_epochs)

    SVR_LOGGER.info("[Epoch-End]: {} , duration: {} , time remaining: {}".format(epoch, elapsed, time_left))
    TRAINING_STATUS['remaining'] = time_left

    # without a loss value the key is removed from the status
    if 'loss' not in logs:
        TRAINING_STATUS.pop('loss', None)  # None to prevent KeyError if key not given
        return

    curLoss = round(logs['loss'], 5)
    LOSS_ALL.append({'loss': curLoss, 'epoch': epoch})
    TRAINING_STATUS['loss'] = {'current': curLoss, 'all': LOSS_ALL}
def trainingStatusChanged():
    ''' Emits the current TRAINING_STATUS as a "status" event through the socketio instance. '''
    # an earlier attempt that emitted a "statusUpdate" event from inside
    # app.test_request_context('/') did not work (original author's note)
    emit_options = {'broadcast': True, 'include_self': True}
    socketio.emit("status", TRAINING_STATUS, **emit_options)
def allowed_file(filename):
    ''' Tells whether filename has an extension that is listed in ALLOWED_EXTENSIONS. '''
    # a name without a dot has no extension at all
    if '.' not in filename:
        return False
    # compare the part after the last dot, ignoring case
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
def validateFiles(files):
    '''
    Validates the uploaded files, saves the accepted ones to the upload
    folder and returns their paths as a dict.
    (Style: filename = path)
    See http://flask.pocoo.org/docs/1.0/patterns/fileuploads/

    files: iterable of uploaded file objects offering .filename and .save()
           (see werkzeug.datastructures.FileStorage).
    Files with an empty name, a not allowed extension or a name that is
    empty after securing it are skipped and logged.
    '''
    SVR_LOGGER.info("Files: {}".format(files))

    emptyName = 0
    filesOut = {}

    # http://werkzeug.pocoo.org/docs/0.14/datastructures/#werkzeug.datastructures.FileStorage
    for upload in files:
        # guard against a missing entry before touching .filename
        if upload is not None and upload.filename == '':
            emptyName += 1
            continue

        if not upload:
            SVR_LOGGER.warning("File invalid: {}".format(upload))
            continue

        if not allowed_file(upload.filename):
            SVR_LOGGER.warning("Filename not allowed: {}".format(upload))
            continue

        # get a secure filename
        # see here: http://werkzeug.pocoo.org/docs/0.14/utils/#werkzeug.utils.secure_filename
        filename = secure_filename(upload.filename)

        # secure_filename may strip everything; saving would then target
        # the upload folder itself, so such a file is rejected
        if not filename:
            SVR_LOGGER.warning("Filename not allowed: {}".format(upload))
            continue

        # save file to upload folder
        path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        upload.save(path)
        SVR_LOGGER.info("File saved: {}".format(path))

        # save path
        filesOut[filename] = path

    if emptyName > 0:
        SVR_LOGGER.warning("Files with empty name: {}".format(emptyName))

    return filesOut
def validateSettings(settings_in):
    '''
    Validates the submitted settings.

    settings_in: form data offering getlist(key) (e.g. request.form).
    Returns a dict (key=value pairs) with the validated values on success
    or an error message (str) for the first missing or invalid key.
    '''
    settings = {}

    for key in SETTINGS['keys']:
        setting = settings_in.getlist(key)

        # let checkboxes pass because they are not given if unchecked!
        if len(setting) <= 0 and key not in SETTINGS['checkboxes']:
            return "Missing key {}!".format(key)

        # handle and validate key types
        if key in SETTINGS['radio']:
            value = str(setting[0])
            if value not in SETTINGS[key + "_options"]:
                # log the key; no exception object exists at this point
                SVR_LOGGER.error("Invalid option for key {}".format(key))
                return "Invalid option for key {}!".format(key)
        elif key in SETTINGS['checkboxes']:
            # an absent checkbox is unchecked; a submitted one is checked
            # if its value is truthy (any non-empty string)
            value = bool(setting[0]) if len(setting) > 0 else False
        else:
            # numeric setting: try int first, fall back to float
            raw = setting[0]
            try:
                value = int(raw)
            except (TypeError, ValueError):
                try:
                    value = float(raw)
                except (TypeError, ValueError):
                    SVR_LOGGER.error("Exception converting value! {}".format(
                        "Value doesn't match type Integer or Float!"))
                    return "Wrong setting format for key {}!".format(key)

            # written as a chained comparison so that NaN is rejected too
            if not (SETTINGS[key + "_min"] <= value <= SETTINGS[key + "_max"]):
                return "Value for key {} out of bounds!".format(key)

        settings[key] = value
        SVR_LOGGER.info("Validating key {}={} was successful.".format(key, value))

    return settings
def convertMidiFiles(filePaths_midi):
    '''
    Converts the given MIDI files to JSON with the MIDI_Converter.
    Returns the path of the folder the JSON files were written to.
    Raises an Exception if the converter reports no success.
    '''
    # every conversion gets its own timestamped output folder
    json_folder = JSON_FOLDER.format(getTimestampNow())

    midi_converter = MIDI_Converter()
    SVR_LOGGER.info("Converting files...")
    outcome = midi_converter.convertFiles(
        paths=filePaths_midi,
        outputPath=json_folder,
        logger=SVR_LOGGER
    )

    succeeded = outcome['success']
    if succeeded == False:
        raise Exception("Failed to convert midi files!")

    # outcome['data'] (list of paths) is not used, only the folder is returned
    return json_folder
def getResultFilepaths():
    '''
    Returns the paths of all midi files (*.mid) in the result folder as a list.

    The list is empty if RESULT_FOLDER is not defined or the folder does
    not exist yet. A plain list is returned in every case because the
    callers embed the value in their own JSON payload.
    '''
    if "RESULT_FOLDER" not in globals():
        SVR_LOGGER.warning("RESULT_FOLDER variable is not defined!")
        return []

    # check result folder path
    resultPath = RESULT_FOLDER
    if not resultPath.endswith("/"):
        resultPath += "/"

    if not os.path.exists(resultPath):
        SVR_LOGGER.info("Result path does not exist yet but was requested.")
        return []

    # get all midi files from this folder
    return list(glob.glob(resultPath + "*.mid"))
# start the server only when this file is executed as a script
if __name__ == "__main__":
    main()