1
+ import glob
1
2
import os
3
+ import re
2
4
3
5
import keras
6
+ import numpy as np
4
7
5
8
import reacdiff .data .data as datamod
6
9
import reacdiff .utils as utils
@@ -11,14 +14,40 @@ def predict(args):
11
14
print ('Loading data' )
12
15
data = datamod .Dataset (
13
16
datamod .load_data (args .data_path ),
17
+ targets = None if args .targets_path is None else datamod .load_csv (args .targets_path ),
14
18
data2 = None if args .data_path2 is None else datamod .load_data (args .data_path2 )
15
19
)
16
20
17
- os .makedirs (os .path .dirname (args .save_path ), exist_ok = True )
21
+ os .makedirs (os .path .dirname (os . path . abspath ( args .save_path ) ), exist_ok = True )
18
22
19
- # Load model
20
- model = keras .models .load_model (args .model , custom_objects = {'rmse' : utils .rmse , 'mae' : utils .mae })
23
+ # Walk directory for ensemble of models
24
+ if os .path .isdir (args .model ):
25
+ model_dirs = glob .iglob (os .path .join (args .model , 'model*' ))
26
+ model_nums = [re .search ('\d+' , os .path .basename (d ))[0 ] for d in model_dirs ]
27
+ model_nums .sort ()
21
28
22
- # Predict
23
- preds = model .predict (data .get_data (), batch_size = args .batch_size , verbose = 1 )
29
+ targets_size = args .targets_size if args .targets_path is None else data .targets .shape [1 ]
30
+ all_preds = np .zeros ((len (model_nums ), len (data ), targets_size ))
31
+
32
+ for model_idx in model_nums :
33
+ print (f'Evaluating model { model_idx } ' )
34
+
35
+ model_path = os .path .join (args .model , f'model{ model_idx } ' , 'model.h5' )
36
+ model = keras .models .load_model (model_path , custom_objects = {'rmse' : utils .rmse , 'mae' : utils .mae })
37
+
38
+ preds = model .predict (data .get_data (), batch_size = args .batch_size , verbose = 1 )
39
+ all_preds [model_idx ] = preds
40
+ preds = np .mean (all_preds , axis = 0 )
41
+ else :
42
+ # Load model
43
+ model = keras .models .load_model (args .model , custom_objects = {'rmse' : utils .rmse , 'mae' : utils .mae })
44
+
45
+ # Predict
46
+ preds = model .predict (data .get_data (), batch_size = args .batch_size , verbose = 1 )
47
+
48
+ if args .targets_path is not None :
49
+ print ('Evaluating ensemble' )
50
+ rmse = utils .rmse_np (data .targets , preds )
51
+ mae = utils .mae_np (data .targets , preds )
52
+ print (f'rmse: { rmse :.4f} ; mae: { mae :.4f} ' )
24
53
datamod .save_csv (preds , args .save_path )
0 commit comments