From c237987bf32e042d4b03ea4334433191c66c78c2 Mon Sep 17 00:00:00 2001 From: Ari Crellin-Quick Date: Tue, 7 Mar 2017 11:39:08 -0800 Subject: [PATCH 1/3] Allow for prediction of subset of a dataset by specifying TS names --- cesium_app/handlers/prediction.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cesium_app/handlers/prediction.py b/cesium_app/handlers/prediction.py index 3220ae5..70f5f11 100644 --- a/cesium_app/handlers/prediction.py +++ b/cesium_app/handlers/prediction.py @@ -67,6 +67,7 @@ def post(self): dataset_id = data['datasetID'] model_id = data['modelID'] + ts_names = data.get('ts_names') dataset = Dataset.get(Dataset.id == data["datasetID"]) model = Model.get(Model.id == data["modelID"]) @@ -88,7 +89,12 @@ def post(self): executor = yield self._get_executor() - all_time_series = executor.map(time_series.load, dataset.uris) + if ts_names: + ts_uris = [f.uri for f in dataset.files if f.name in ts_names] + else: + ts_uris = dataset.uris + + all_time_series = executor.map(time_series.load, ts_uris) all_labels = executor.map(lambda ts: ts.label, all_time_series) all_features = executor.map(featurize.featurize_single_ts, all_time_series, From 3639646887d3b86c19c4de0dea777976506389f3 Mon Sep 17 00:00:00 2001 From: Ari Crellin-Quick Date: Wed, 29 Mar 2017 11:58:16 -0700 Subject: [PATCH 2/3] Add test for specifying TS name in predict call --- cesium_app/handlers/prediction.py | 8 ++++-- cesium_app/tests/frontend/test_predict.py | 34 +++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/cesium_app/handlers/prediction.py b/cesium_app/handlers/prediction.py index 70f5f11..8ef098f 100644 --- a/cesium_app/handlers/prediction.py +++ b/cesium_app/handlers/prediction.py @@ -25,10 +25,10 @@ def _get_prediction(self, prediction_id): try: d = Prediction.get(Prediction.id == prediction_id) except Prediction.DoesNotExist: - raise AccessError('No such dataset') + raise AccessError('No such prediction') if not d.is_owned_by(self.get_username()): - raise AccessError('No such dataset') + raise AccessError('No such prediction') return d @@ -90,7 +90,9 @@ def post(self): executor = yield self._get_executor() if ts_names: - ts_uris = [f.uri for f in dataset.files if f.name in ts_names] + ts_uris = [f.uri for f in dataset.files if os.path.basename(f.name) + in ts_names or os.path.basename(f.name).split('.npz')[0] + in ts_names] else: ts_uris = dataset.uris diff --git a/cesium_app/tests/frontend/test_predict.py b/cesium_app/tests/frontend/test_predict.py index cddfc0a..5f40c30 100644 --- a/cesium_app/tests/frontend/test_predict.py +++ b/cesium_app/tests/frontend/test_predict.py @@ -7,6 +7,9 @@ from os.path import join as pjoin import numpy as np import numpy.testing as npt +from cesium_app.config import cfg +import json +import requests from cesium_app.tests.fixtures import (create_test_project, create_test_dataset, create_test_featureset, create_test_model, create_test_prediction) @@ -204,3 +207,34 @@ def test_download_prediction_csv_regr(driver): [4, 3.1, 3.1]]) finally: os.remove('/tmp/cesium_prediction_results.csv') + + +def test_predict_specific_ts_name(): + with create_test_project() as p, create_test_dataset(p) as ds,\ + create_test_featureset(p) as fs, create_test_model(fs) as m: + ts_data = [[1, 2, 3, 4], [32.2, 53.3, 32.3, 32.52], [0.2, 0.3, 0.6, 0.3]] + impute_kwargs = {'strategy': 'constant', 'value': None} + data = {'datasetID': ds.id, + 'ts_names': ['217801'], + 'modelID': m.id} + print('data:', data) + response = requests.post('{}/predictions'.format(cfg['server']['url']), + data=json.dumps(data)).json() + print('response dict:', response) + assert response['status'] == 'success' + + n_secs = 0 + while n_secs < 5: + pred_info = requests.get('{}/predictions/{}'.format( + cfg['server']['url'], response['data']['id'])).json() + print(pred_info) + if pred_info['status'] == 'success' and pred_info['data']['finished']: + assert isinstance(pred_info['data']['results']['217801'] + ['features']['total_time'], + float) + assert 'Mira' in pred_info['data']['results']['217801']['prediction'] + break + n_secs += 1 + time.sleep(1) + else: + raise Exception('test_predict_specific_ts_name timed out') From c0ea7e40b3b9225a59c4e147fe6c7a006baf83d8 Mon Sep 17 00:00:00 2001 From: Ari Crellin-Quick Date: Mon, 3 Apr 2017 12:43:26 -0700 Subject: [PATCH 3/3] Remove outdated line from tools/watch_logs --- cesium_app/handlers/prediction.py | 3 +++ cesium_app/tests/frontend/test_predict.py | 3 --- tools/watch_logs.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cesium_app/handlers/prediction.py b/cesium_app/handlers/prediction.py index 8ef098f..a36b3af 100644 --- a/cesium_app/handlers/prediction.py +++ b/cesium_app/handlers/prediction.py @@ -67,6 +67,8 @@ def post(self): dataset_id = data['datasetID'] model_id = data['modelID'] + # If only a subset of specified dataset is to be used, a list of the + # corresponding time series file names can be provided ts_names = data.get('ts_names') dataset = Dataset.get(Dataset.id == data["datasetID"]) @@ -89,6 +91,7 @@ def post(self): executor = yield self._get_executor() + # If only a subset of the dataset is to be used, get specified files if ts_names: ts_uris = [f.uri for f in dataset.files if os.path.basename(f.name) in ts_names or os.path.basename(f.name).split('.npz')[0] diff --git a/cesium_app/tests/frontend/test_predict.py b/cesium_app/tests/frontend/test_predict.py index 5f40c30..05f3f3e 100644 --- a/cesium_app/tests/frontend/test_predict.py +++ b/cesium_app/tests/frontend/test_predict.py @@ -217,17 +217,14 @@ def test_predict_specific_ts_name(): data = {'datasetID': ds.id, 'ts_names': ['217801'], 'modelID': m.id} - print('data:', data) response = requests.post('{}/predictions'.format(cfg['server']['url']), data=json.dumps(data)).json() - print('response dict:', response) assert response['status'] == 'success' n_secs = 0 while n_secs < 5: pred_info = requests.get('{}/predictions/{}'.format( cfg['server']['url'], response['data']['id'])).json() - print(pred_info) if pred_info['status'] == 'success' and pred_info['data']['finished']: assert isinstance(pred_info['data']['results']['217801'] ['features']['total_time'], diff --git a/tools/watch_logs.py b/tools/watch_logs.py index 8689a4a..8c38d82 100755 --- a/tools/watch_logs.py +++ b/tools/watch_logs.py @@ -96,7 +96,6 @@ def logs_from_config(supervisor_conf): with nostdout(): from cesium_app.config import cfg -watched.append(cfg['paths']['err_log_path']) watched.append('log/error.log') watched.append('log/nginx-error.log')