Skip to content

Commit 9e6cb61

Browse files
author
Colin Grambow
committed
Modify predictions to work with ensembles
1 parent 23d06ff commit 9e6cb61

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

reacdiff/parsing.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def parse_predict_args():
1919
help='Path to data containing states for prediction task')
2020
parser.add_argument('--model', type=str, required=True,
2121
help='Path to trained model')
22+
parser.add_argument('--targets_path', type=str,
23+
help='Path to targets')
24+
parser.add_argument('--targets_size', type=int, default=4,
25+
help='Size of target vector. Only required if targets_path is not specified.')
2226
parser.add_argument('--data_path2', type=str,
2327
help='Path to additional observable states for prediction')
2428
parser.add_argument('--save_path', type=str, default=os.path.join(os.getcwd(), 'preds.csv'),

reacdiff/train/predict.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import glob
12
import os
3+
import re
24

35
import keras
6+
import numpy as np
47

58
import reacdiff.data.data as datamod
69
import reacdiff.utils as utils
@@ -11,14 +14,40 @@ def predict(args):
1114
print('Loading data')
1215
data = datamod.Dataset(
1316
datamod.load_data(args.data_path),
17+
targets=None if args.targets_path is None else datamod.load_csv(args.targets_path),
1418
data2=None if args.data_path2 is None else datamod.load_data(args.data_path2)
1519
)
1620

17-
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
21+
os.makedirs(os.path.dirname(os.path.abspath(args.save_path)), exist_ok=True)
1822

19-
# Load model
20-
model = keras.models.load_model(args.model, custom_objects={'rmse': utils.rmse, 'mae': utils.mae})
23+
# Walk directory for ensemble of models
24+
if os.path.isdir(args.model):
25+
model_dirs = glob.iglob(os.path.join(args.model, 'model*'))
26+
model_nums = [re.search('\d+', os.path.basename(d))[0] for d in model_dirs]
27+
model_nums.sort()
2128

22-
# Predict
23-
preds = model.predict(data.get_data(), batch_size=args.batch_size, verbose=1)
29+
targets_size = args.targets_size if args.targets_path is None else data.targets.shape[1]
30+
all_preds = np.zeros((len(model_nums), len(data), targets_size))
31+
32+
for model_idx in model_nums:
33+
print(f'Evaluating model {model_idx}')
34+
35+
model_path = os.path.join(args.model, f'model{model_idx}', 'model.h5')
36+
model = keras.models.load_model(model_path, custom_objects={'rmse': utils.rmse, 'mae': utils.mae})
37+
38+
preds = model.predict(data.get_data(), batch_size=args.batch_size, verbose=1)
39+
all_preds[model_idx] = preds
40+
preds = np.mean(all_preds, axis=0)
41+
else:
42+
# Load model
43+
model = keras.models.load_model(args.model, custom_objects={'rmse': utils.rmse, 'mae': utils.mae})
44+
45+
# Predict
46+
preds = model.predict(data.get_data(), batch_size=args.batch_size, verbose=1)
47+
48+
if args.targets_path is not None:
49+
print('Evaluating ensemble')
50+
rmse = utils.rmse_np(data.targets, preds)
51+
mae = utils.mae_np(data.targets, preds)
52+
print(f'rmse: {rmse:.4f}; mae: {mae:.4f}')
2453
datamod.save_csv(preds, args.save_path)

0 commit comments

Comments
 (0)