diff --git a/talos/commands/predict.py b/talos/commands/predict.py index f8291c0..2b89d2f 100644 --- a/talos/commands/predict.py +++ b/talos/commands/predict.py @@ -44,6 +44,8 @@ def predict_classes(self, x, metric, asc, model_id=None): ''' + import numpy as np + if model_id is None: from ..utils.best_model import best_model model_id = best_model(self.scan_object, metric, asc) @@ -51,4 +53,8 @@ def predict_classes(self, x, metric, asc, model_id=None): from ..utils.best_model import activate_model model = activate_model(self.scan_object, model_id) - return model.predict_classes(x) + # make (class) predictiosn with the model + preds = model.predict(x) + preds_classes = np.argmax(preds, axis=1) + + return preds_classes diff --git a/tests/commands/test_latest.py b/tests/commands/test_latest.py index 792aae8..9e358f3 100644 --- a/tests/commands/test_latest.py +++ b/tests/commands/test_latest.py @@ -1,5 +1,9 @@ def test_latest(): + import warnings + + warnings.simplefilter('ignore') + print('\n >>> start Latest Features... \n') import talos