diff --git a/src_py/apiServer/stats.py b/src_py/apiServer/stats.py index 7c5252ce..6a81b267 100644 --- a/src_py/apiServer/stats.py +++ b/src_py/apiServer/stats.py @@ -373,7 +373,10 @@ def get_model_performence_stats(self , confusion_matrix_worker_dict , show : boo workers_performence = OrderedDict() for (worker_name, class_name) in confusion_matrix_worker_dict.keys(): workers_performence[(worker_name, class_name)] = OrderedDict() - tn, fp, fn, tp = confusion_matrix_worker_dict[(worker_name, class_name)].ravel() + matrix = confusion_matrix_worker_dict[(worker_name, class_name)] + if len(matrix.shape) == 1 or matrix.shape == (1, 1): # if the matrix is 1D or 1x1 skip to prevent errors in the calculations + continue + tn, fp, fn, tp = matrix.ravel() if printStats: LOG_INFO(f"worker {worker_name} class: {class_name} tn: {tn}, fp: {fp}, fn: {fn}, tp: {tp}") tn = int(tn)