@@ -141,6 +141,23 @@ def train(self):
141
141
else :
142
142
print ("Initializing model from scratch." )
143
143
144
+ if self .config .LOAD_PATH and not self .config .TRAIN_PATH :
145
+ model_dirname = self .config .LOAD_PATH
146
+ elif self .config .MODEL_PATH :
147
+ model_dirname = self .config .MODEL_PATH
148
+ else :
149
+ model_dirname = None
150
+ print ('Model directory is missing' )
151
+ exit (- 1 )
152
+
153
+ stats_file_name = os .path .join (model_dirname , "stats.txt" )
154
+ loss_file_name = os .path .join (model_dirname , "avg_loss.txt" )
155
+ try :
156
+ os .remove (stats_file_name )
157
+ os .remove (loss_file_name )
158
+ except OSError :
159
+ pass
160
+
144
161
sum_loss = 0
145
162
batch_num = 0
146
163
epochs_trained = 0
@@ -194,7 +211,7 @@ def train(self):
194
211
batch_num += 1
195
212
196
213
if batch_num % self .num_batches_to_log == 0 :
197
- self .trace (pbar , sum_loss , batch_num , multi_batch_start_time )
214
+ self .trace (pbar , sum_loss , batch_num , multi_batch_start_time , loss_file_name )
198
215
sum_loss = 0
199
216
multi_batch_start_time = time .time ()
200
217
@@ -211,7 +228,12 @@ def train(self):
211
228
checkpoint_manager .save ()
212
229
213
230
# validate model to calculate metrics or stop training
214
- results , precision , recall , f1 , rouge = self .evaluate ()
231
+ results , precision , recall , f1 , rouge = self .evaluate (model_dirname )
232
+
233
+ # Add results to a stats file for later processing of graphs
234
+ with open (stats_file_name , "a+" ) as stats_file :
235
+ stats_file .write ("{0}, {1}, {2}, {3}, {4}\n " .format (epochs_trained , results , precision , recall , f1 ))
236
+
215
237
if self .config .BEAM_WIDTH == 0 :
216
238
print ("Accuracy after %d epochs: %.5f" % (epochs_trained , results ))
217
239
else :
@@ -252,22 +274,15 @@ def train(self):
252
274
% ((elapsed // 60 // 60 ), (elapsed // 60 ) % 60 , elapsed % 60 )
253
275
)
254
276
255
- def evaluate (self ):
277
+ def evaluate (self , model_dirname ):
256
278
if not self .model :
257
279
print ("Model is not initialized" )
258
280
exit (- 1 )
259
281
260
282
print ("Testing..." )
261
283
eval_start_time = time .time ()
262
284
263
- if self .config .LOAD_PATH and not self .config .TRAIN_PATH :
264
- model_dirname = self .config .LOAD_PATH
265
- elif self .config .MODEL_PATH :
266
- model_dirname = self .config .MODEL_PATH
267
- else :
268
- model_dirname = None
269
- print ('Model directory is missing' )
270
- exit (- 1 )
285
+
271
286
272
287
ref_file_name = os .path .join (model_dirname , "ref.txt" )
273
288
predicted_file_name = os .path .join (model_dirname , "pred.txt" )
@@ -388,15 +403,16 @@ def evaluate(self):
388
403
389
404
elapsed = int (time .time () - eval_start_time )
390
405
precision , recall , f1 = calculate_results (true_positive , false_positive , false_negative )
391
-
406
+ accuracy = num_correct_predictions / total_predictions
407
+
392
408
try :
393
409
files_rouge = FilesRouge (predicted_file_name , ref_file_name )
394
410
rouge = files_rouge .get_scores (avg = True , ignore_empty = True )
395
411
except ValueError :
396
412
rouge = 0
397
413
398
414
print ("Evaluation time: %sh%sm%ss" % ((elapsed // 60 // 60 ), (elapsed // 60 ) % 60 , elapsed % 60 ))
399
- return num_correct_predictions / total_predictions , precision , recall , f1 , rouge
415
+ return accuracy , precision , recall , f1 , rouge
400
416
401
417
def print_hyperparams (self ):
402
418
print ("Training batch size:\t \t \t " , self .config .BATCH_SIZE )
@@ -422,17 +438,18 @@ def print_hyperparams(self):
422
438
print ("LSTM dropout keep_prob:\t \t \t " , self .config .RNN_DROPOUT_KEEP_PROB )
423
439
print ("============================================" )
424
440
425
def trace(self, pbar, sum_loss, batch_num, multi_batch_start_time, loss_file_name):
    """Report training progress for the batches logged since the last call.

    Computes the mean loss and sample throughput over the window that started
    at `multi_batch_start_time`, shows both on the progress bar, and appends a
    "batch, avg_loss, throughput" line to `loss_file_name` for later plotting.
    """
    window_elapsed = time.time() - multi_batch_start_time
    # Fall back to 1 second when the window is zero-length so the
    # throughput division below can never fail.
    safe_elapsed = window_elapsed if window_elapsed > 0 else 1
    mean_loss = sum_loss / self.num_batches_to_log
    samples_per_sec = self.config.BATCH_SIZE * self.num_batches_to_log / safe_elapsed

    pbar.set_description(
        "Average loss at batch {0}: {1}, \tthroughput: {2} samples/sec".format(
            batch_num, mean_loss, samples_per_sec
        )
    )
    # Persist the same numbers so loss curves can be rebuilt after training.
    with open(loss_file_name, "a+") as log:
        log.write("{0}, {1}, {2}\n".format(batch_num, mean_loss, samples_per_sec))
436
453
437
454
def encode (self , predict_data_lines ):
438
455
if not self .model :
0 commit comments