diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py index 9d5d12012..3a8f0a4d1 100644 --- a/digits/model/images/classification/views.py +++ b/digits/model/images/classification/views.py @@ -529,6 +529,11 @@ def classify_many(): 'Unable to classify any image from the file') scores = last_output_data + # force correct 2D shape squeezing scores + for i in reversed(range(2, len(scores.shape))): + if scores.shape[i] == 1: + scores = np.squeeze(scores, axis=(i,)) + # take top 5 indices = (-scores).argsort()[:, :5]