From 24601006c2aaac3de2e2df57a65460c208779bdf Mon Sep 17 00:00:00 2001 From: Mikko Kotila Date: Fri, 28 Jan 2022 21:59:38 +0200 Subject: [PATCH 1/2] fixed predict_classes in predict --- talos/commands/predict.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/talos/commands/predict.py b/talos/commands/predict.py index f8291c0d..2b89d2fc 100644 --- a/talos/commands/predict.py +++ b/talos/commands/predict.py @@ -44,6 +44,8 @@ def predict_classes(self, x, metric, asc, model_id=None): ''' + import numpy as np + if model_id is None: from ..utils.best_model import best_model model_id = best_model(self.scan_object, metric, asc) @@ -51,4 +53,8 @@ def predict_classes(self, x, metric, asc, model_id=None): from ..utils.best_model import activate_model model = activate_model(self.scan_object, model_id) - return model.predict_classes(x) + # make (class) predictiosn with the model + preds = model.predict(x) + preds_classes = np.argmax(preds, axis=1) + + return preds_classes From dace33af88ff2bed993ca3a16e9ccb54fd0a2f0f Mon Sep 17 00:00:00 2001 From: Mikko Kotila Date: Fri, 28 Jan 2022 22:06:04 +0200 Subject: [PATCH 2/2] cleanup --- tests/commands/test_latest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/commands/test_latest.py b/tests/commands/test_latest.py index 792aae85..9e358f30 100644 --- a/tests/commands/test_latest.py +++ b/tests/commands/test_latest.py @@ -1,5 +1,9 @@ def test_latest(): + import warnings + + warnings.simplefilter('ignore') + print('\n >>> start Latest Features... \n') import talos