diff --git a/talos/commands/predict.py b/talos/commands/predict.py index f8291c0..2b89d2f 100644 --- a/talos/commands/predict.py +++ b/talos/commands/predict.py @@ -44,6 +44,8 @@ def predict_classes(self, x, metric, asc, model_id=None): ''' + import numpy as np + if model_id is None: from ..utils.best_model import best_model model_id = best_model(self.scan_object, metric, asc) @@ -51,4 +53,8 @@ def predict_classes(self, x, metric, asc, model_id=None): from ..utils.best_model import activate_model model = activate_model(self.scan_object, model_id) - return model.predict_classes(x) + # make (class) predictiosn with the model + preds = model.predict(x) + preds_classes = np.argmax(preds, axis=1) + + return preds_classes