-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval_rocket.py
113 lines (95 loc) · 3.64 KB
/
eval_rocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
########################################################################
#
# @author : Emmanouil Sylligardos
# @when : Winter Semester 2022/2023
# @where : LIPADE internship Paris
# @title : MSAD (Model Selection Anomaly Detection)
# @component: root
# @file : eval_rocket
#
########################################################################
import argparse
import re
import os
from datetime import datetime
from utils.timeseries_dataset import read_files, create_splits
from utils.evaluator import Evaluator, load_classifier
def eval_rocket(data_path, model_path, path_save=None, fnames=None, read_from_file=None):
    """Predict time series with the given ROCKET model.

    The window size is taken from the first run of digits found in
    ``data_path`` and is used to build the classifier name
    (``rocket_<window_size>`` plus an optional suffix derived from
    ``read_from_file`` or ``model_path``). That name prefixes every column of
    the results and names the saved CSV file.

    :param data_path: Path to the data to predict: a directory, or a single
        ``.csv`` file. Must contain the window size as digits.
    :param model_path: Path to the model to load and use for predictions.
    :param path_save: Directory to save the evaluation results in (created if
        it does not exist). Nothing is saved when ``None``.
    :param fnames: List of file names (time series) to predict. Mutually
        exclusive with ``read_from_file``.
    :param read_from_file: File to read which time series to predict from a
        given path. Mutually exclusive with ``fnames``.
    :raises ValueError: If ``data_path`` contains no digits, or if some of the
        requested ``fnames`` are not found under ``data_path``.
    :raises AssertionError: If both ``fnames`` and ``read_from_file`` are given.
    :return: The object returned by ``Evaluator.predict`` with its columns
        prefixed by the classifier name (presumably a pandas DataFrame, since
        ``.columns`` and ``.to_csv`` are used on it).
    """
    # NOTE(review): the *first* digit run in the path is used, so a digit in a
    # parent directory name would be picked up instead - confirm path layout.
    match = re.search(r'\d+', str(data_path))
    if match is None:
        raise ValueError(
            f"Could not infer the window size: no digits found in data path '{data_path}'"
        )
    window_size = int(match.group())

    # Raised explicitly (instead of using an `assert` statement) so that the
    # check is not stripped when Python runs with -O; the exception type is
    # kept for backward compatibility.
    if fnames is not None and read_from_file is not None:
        raise AssertionError(
            "You should provide either the fnames or the path to the specific splits, not both"
        )

    # Build the name that identifies this model in column names and file names
    classifier_name = f"rocket_{window_size}"
    if read_from_file is not None and "unsupervised" in read_from_file:
        split_file = read_from_file.split('/')[-1].replace('unsupervised_', '')
        classifier_name += f"_{split_file[:-len('.csv')]}"
    elif "testsize_" in model_path:
        # The model's parent directory carries the extra suffix
        extra = model_path.split('/')[-2].replace(classifier_name, "")
        classifier_name += extra

    # Load model
    model = load_classifier(model_path)

    if read_from_file is not None:
        # Load the splits; fall back to the validation set when there is no
        # test set. `len(...)` is used rather than truthiness because the
        # container type returned by create_splits is not visible here.
        _, val_set, test_set = create_splits(
            data_path,
            read_from_file=read_from_file,
        )
        fnames = test_set if len(test_set) > 0 else val_set
    else:
        # Read data (single csv file or directory with csvs)
        if str(data_path).endswith('.csv'):
            data_path, single_file = os.path.split(str(data_path))
            available = [single_file]
        else:
            available = read_files(data_path)

        # Keep specific time series if fnames is given
        if fnames is not None:
            requested = set(fnames)
            selected = [x for x in available if x in requested]
            if len(selected) != len(fnames):
                raise ValueError("The data path does not include the time series in fnames")
            fnames = selected
        else:
            fnames = available

    # Compute predictions and inference time
    evaluator = Evaluator()
    results = evaluator.predict(
        model=model,
        fnames=fnames,
        data_path=data_path,
        deep_model=False,
    )
    results.columns = [f"{classifier_name}_{x}" for x in results.columns.values]

    # Print results
    print(results)

    # Save the results
    if path_save is not None:
        # to_csv does not create missing directories
        os.makedirs(path_save, exist_ok=True)
        file_name = os.path.join(path_save, f"{classifier_name}_preds.csv")
        results.to_csv(file_name)

    return results
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the evaluation.
    parser = argparse.ArgumentParser(
        prog='Evaluate rocket models',
        # Implicit string concatenation keeps source indentation out of the
        # help text (a backslash continuation inside the literal embeds it).
        description=(
            'Evaluate rocket models '
            'on a single or multiple time series and save the results'
        ),
    )

    parser.add_argument('-d', '--data', type=str, help='path to the time series to predict', required=True)
    parser.add_argument('-mp', '--model_path', type=str, help='path to the trained model', required=True)
    parser.add_argument('-ps', '--path_save', type=str, help='path to save the results', default="results/raw_predictions")
    parser.add_argument('-f', '--file', type=str, help='path to file that contains a specific split', default=None)

    args = parser.parse_args()

    eval_rocket(
        data_path=args.data,
        model_path=args.model_path,
        path_save=args.path_save,
        read_from_file=args.file,
    )