Skip to content

Commit

Permalink
change for PR
Browse files Browse the repository at this point in the history
  • Loading branch information
Orisadek committed Aug 3, 2024
1 parent 9e83dd6 commit f9569da
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 72 deletions.
72 changes: 0 additions & 72 deletions src_py/apiServer/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,78 +146,6 @@ def expend_labels_df(self, df):
return df


def survey(results, category_names):
    """
    Draw a horizontal stacked bar chart with one bar per key of *results*.

    Parameters
    ----------
    results : dict
        Maps each question label to its list of per-category values.
        Every list is expected to have the same number of entries as
        *category_names*.
    category_names : list of str
        Legend label for each category (one stacked segment per category).

    Returns
    -------
    fig, ax
        The figure and axes objects created by ``plt.subplots``.
    """
    row_labels = list(results.keys())
    values = np.array(list(results.values()))
    right_edges = values.cumsum(axis=1)
    palette = plt.colormaps['RdYlGn'](np.linspace(0.15, 0.85, values.shape[1]))

    fig, ax = plt.subplots(figsize=(9.2, 5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.set_xlim(0, np.sum(values, axis=1).max())

    for col, (category, rgba) in enumerate(zip(category_names, palette)):
        segment = values[:, col]
        # Each segment starts where the previous category's segment ended.
        bars = ax.barh(row_labels, segment, left=right_edges[:, col] - segment,
                       height=0.5, label=category, color=rgba)

        # White text when the RGB product is below 0.5, dark grey otherwise.
        red, green, blue, _ = rgba
        if red * green * blue < 0.5:
            label_color = 'white'
        else:
            label_color = 'darkgrey'
        ax.bar_label(bars, label_type='center', color=label_color)

    ax.legend(ncols=len(category_names), bbox_to_anchor=(0, 1), loc='lower left', fontsize='small')
    return fig, ax

# Builds, for every worker, a list of per-label counters from the source-piece label CSVs
# and hands the result to self.survey, then shows the plot.
# Only valid in the training phase (asserted below). The normalize, plot and saveToFile
# parameters are not referenced anywhere in the visible body, and nothing is returned.
# NOTE(review): indentation was lost in this view, so the exact nesting of the loops below
# (in particular of the counter increment near the end) cannot be confirmed from here.
def get_distributed_train_labels(self , normalize : bool = False ,plot : bool = False , saveToFile : bool = False):
assert self.phase == PHASE_TRAINING_STR, "This function is only available for train phase"
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
# One entry per worker name; starts as an empty list and is sized lazily further down.
dict_worker = {}
for worker in workers_model_db_list:
dict_worker[worker.get_worker_name()] = []
# Category names later passed to survey; overwritten for every source piece.
labels_d = []
for source_piece_inst in sources_pieces_list:
sourcePiece_csv_labels_path = source_piece_inst.get_pointer_to_sourcePiece_CsvDataSet()
df_actual_labels = pd.read_csv(sourcePiece_csv_labels_path)
# NOTE(review): looks like a leftover debug print of the whole DataFrame.
print(df_actual_labels,"df_actual_labels")
# The number of CSV columns is used as the number of labels, and the columns are
# renamed to 0..num_of_labels-1.
# presumably the CSV holds only label columns — TODO confirm
num_of_labels = df_actual_labels.shape[1]
header_list = range(num_of_labels)
df_actual_labels.columns = header_list
labels_d = header_list
source_name = source_piece_inst.get_source_name()
# Process each worker that this source piece targets.
target_workers = source_piece_inst.get_target_workers()
batch_size = source_piece_inst.get_batch_size()
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
# Skip workers that are not targets of this source piece.
if worker_name not in target_workers:
continue
# First time this worker is seen: allocate one zeroed counter per label.
if(dict_worker[worker_name] == []):
dict_worker[worker_name] = [0] * num_of_labels
df_worker_labels = df_actual_labels.copy()
total_batches_per_source = worker_db.get_total_batches_per_source(source_name)
for batch_id in range(total_batches_per_source):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # batch is missing
# Emit the warning only once per instance.
if not self.missed_batches_warning_msg:
LOG_WARNING(f"missed batches")
self.missed_batches_warning_msg = True
# Increments one counter for this worker, indexed by sum() over the batch's row slice.
# NOTE(review): the built-in sum() over a DataFrame iterates its column labels, not its
# row values, so this index may not be the batch's label — verify the intended index.
dict_worker[worker_name][sum(df_worker_labels.iloc[batch_id * batch_size: (batch_id + 1) * batch_size])] +=1

# NOTE(review): survey is declared in this file as survey(results, category_names) with no
# self parameter; if it is a method, this bound call would pass three arguments — confirm.
self.survey(dict_worker, labels_d)
# Displays the figure unconditionally; the plot parameter is not consulted.
plt.show()


def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False , saveToFile : bool = False):
assert self.experiment_flow_type == "classification", "This function is only available for classification experiments"
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
Expand Down
11 changes: 11 additions & 0 deletions src_py/apiServer/statsTiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
confusion_matrix_source_dict, confusion_matrix_distributed_dict = self.get_confusion_matrices_tiles(normalize ,plot ,saveToFile )
return confusion_matrix_source_dict, confusion_matrix_distributed_dict


"""
Calculates the confusion matrix for each cluster and each class in the distributed tokens.
Attention: get_confusion_matrices_tiles takes the sum of the labels, performs a majority vote on that sum,
and then decides 0 or 1 for each class.
"""
def get_confusion_matrices_tiles(self , normalize : bool = False ,plot : bool = False , saveToFile : bool = False):
assert self.experiment_flow_type == "classification", "This function is only available for classification experiments"
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
Expand Down Expand Up @@ -168,7 +174,12 @@ def argmax_axis_1(self,list_of_lists):
result = [sublist.index(max(sublist)) for sublist in list_of_lists]
return result

"""
Calculates the confusion matrix for each cluster and each class in the distributed tokens.
Attention: get_confusion_matrices_tiles_new takes the sum of the probabilities, performs a majority vote on that sum,
and then decides 0 or 1 for each class.
"""
def get_confusion_matrices_tiles_new(self , normalize : bool = False ,plot : bool = False , saveToFile : bool = False):
assert self.experiment_flow_type == "classification", "This function is only available for classification experiments"
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
Expand Down

0 comments on commit f9569da

Please sign in to comment.