Skip to content

Commit

Permalink
Merge pull request #384 from leondavi/new_confusion_natrix
Browse files Browse the repository at this point in the history
[new_confusion_matrix] Build a new function that supports any policy
  • Loading branch information
leondavi authored Aug 7, 2024
2 parents ade2879 + 08bbb27 commit 3942784
Showing 1 changed file with 84 additions and 72 deletions.
156 changes: 84 additions & 72 deletions src_py/apiServer/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def __init__(self, experiment_phase: ExperimentPhase):
self.experiment_flow_type = self.experiment_phase.get_experiment_flow_type()
if (self.phase == PHASE_PREDICTION_STR):
for source_piece_inst in self.experiment_phase.get_sources_pieces():
csv_dataset = source_piece_inst.get_csv_dataset_parent()
source_piece_csv_labels_file = csv_dataset.genrate_source_piece_ds_csv_file_labels(source_piece_inst, self.phase)
csv_dataset_inst = source_piece_inst.get_csv_dataset_parent() # get the csv dataset instance (csv_dataset_db.py)
source_piece_csv_labels_file = csv_dataset_inst.genrate_source_piece_ds_csv_file_labels(source_piece_inst, self.phase) #return the path of the csv file that contains the source piece labels data
source_piece_inst.set_pointer_to_sourcePiece_CsvDataSet_labels(source_piece_csv_labels_file)
self.headers_list = csv_dataset.get_headers_row()
self.headers_list = csv_dataset_inst.get_headers_row()

Path(f'{EXPERIMENT_RESULTS_PATH}/{self.experiment_phase.get_experiment_flow_name()}').mkdir(parents=True, exist_ok=True)
Path(f'{EXPERIMENT_RESULTS_PATH}/{self.experiment_phase.get_phase_type()}/{self.experiment_phase.get_experiment_flow_name()}').mkdir(parents=True, exist_ok=True)
Expand All @@ -49,13 +49,13 @@ def get_loss_foh_missing(self , plot : bool = False , saveToFile : bool = False)
pass

#TODO Implement this function
# TODO: implement this function
def get_loss_by_source(self , plot : bool = False , saveToFile : bool = False):
    """
    Return a dictionary of {source : DataFrame[BatchID,'w1','w2'...'wn']}
    for each source in the experiment.

    plot       -- if True, plot the per-source loss (not yet implemented)
    saveToFile -- if True, save the result to a file (not yet implemented)

    NOTE(review): currently a stub; returns None until implemented.
    """
    pass

def get_loss_ts(self , plot : bool = False , saveToFile : bool = False): # Todo change it
def get_loss_ts(self , plot : bool = False , saveToFile : bool = False):
"""
Returns a dictionary of {worker : loss list} for each worker in the experiment.
use plot=True to plot the loss function.
Expand Down Expand Up @@ -83,7 +83,6 @@ def get_loss_ts(self , plot : bool = False , saveToFile : bool = False): # Todo

df = pd.DataFrame(loss_dict)
self.loss_ts_pd = df
#print(df)

if plot:
sns.lineplot(data=df)
Expand Down Expand Up @@ -117,25 +116,6 @@ def get_min_loss(self , plot : bool = False , saveToFile : bool = False): # Todo
plt.title('Training Min Loss')
return min_loss_dict


# if plot:
# plt.figure(figsize = (30,15), dpi = 150)
# plt.rcParams.update({'font.size': 22})
# for worker_name, loss_list in loss_dict.items():
# plt.plot(loss_list, label=worker_name)
# plt.legend(list(loss_dict.keys()))
# plt.xlabel('Batch Number' , fontsize=30)
# plt.ylabel('Loss' , fontsize=30)
# plt.yscale('log')
# plt.xlim(left=0)
# plt.ylim(bottom=0)
# plt.title('Training Loss Function')
# plt.grid(visible=True, which='major', linestyle='-')
# plt.minorticks_on()
# plt.grid(visible=True, which='minor', linestyle='-', alpha=0.7)
# plt.show()
# plt.savefig(f'{EXPERIMENT_RESULTS_PATH}/{self.experiment.name}/Training/Loss_graph.png')

def expend_labels_df(self, df):
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
temp_list = list(range(df.shape[1]))
Expand All @@ -144,65 +124,73 @@ def expend_labels_df(self, df):
df = df.reindex(columns = [*df.columns.tolist(), *temp_list], fill_value = 0)
assert df.shape[1] == 2 * num_of_labels, "Error in expend_labels_df function"
return df


def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False , saveToFile : bool = False):

def build_worker_label_df(original_df, batch_ids, batch_size):
rows_list = []

for batch_id in batch_ids:
# Calculate the start and end indices for the rows to be copied
start_idx = batch_id * batch_size
end_idx = (batch_id + 1) * batch_size

# Extract the rows and append to the list
batch_rows = original_df.iloc[start_idx:end_idx]
rows_list.append(batch_rows)

# Concatenate all the extracted rows into a new DataFrame
df_worker_labels = pd.concat(rows_list, ignore_index=True)
return df_worker_labels

assert self.experiment_flow_type == "classification", "This function is only available for classification experiments"
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
confusion_matrix_source_dict = {}
confusion_matrix_worker_dict = {}
recived_batches_dict = self.get_recieved_batches()
for source_piece_inst in sources_pieces_list:
nerltensorType = source_piece_inst.get_nerltensor_type()
source_name = source_piece_inst.get_source_name()
sourcePiece_csv_labels_path = source_piece_inst.get_pointer_to_sourcePiece_CsvDataSet_labels()
df_actual_labels = pd.read_csv(sourcePiece_csv_labels_path)
num_of_labels = df_actual_labels.shape[1]
header_list = range(num_of_labels)
df_actual_labels.columns = header_list
df_actual_labels = self.expend_labels_df(df_actual_labels)
#print(df_actual_labels)
source_name = source_piece_inst.get_source_name()

# build confusion matrix for each worker
target_workers = source_piece_inst.get_target_workers()
worker_missed_batches = {}
# build confusion matrix for each worker
target_workers_string = source_piece_inst.get_target_workers()
target_workers_names = target_workers_string.split(',')
batch_size = source_piece_inst.get_batch_size()
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
if worker_name not in target_workers:
if worker_name not in target_workers_names:
continue
df_worker_labels = df_actual_labels.copy()
total_batches_per_source = worker_db.get_total_batches_per_source(source_name)
for batch_id in range(total_batches_per_source):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # if batch is missing
if not self.missed_batches_warning_msg:
LOG_WARNING(f"missed batches")
self.missed_batches_warning_msg = True
starting_offset = source_piece_inst.get_starting_offset()
df_worker_labels.iloc[batch_id * batch_size: (batch_id + 1) * batch_size, num_of_labels:] = None # set the actual label to None for the predict labels in the df
worker_missed_batches[(worker_name, source_name, str(batch_id))] = (starting_offset + batch_id * batch_size, batch_size) # save the missing batch
worker_recived_batches_id = recived_batches_dict.get(f"phase:{self.experiment_phase.get_name()},{source_name}->{worker_name}") # get a list of recived batches id for the worker
df_worker_labels = build_worker_label_df(df_actual_labels, worker_recived_batches_id, batch_size)
header_list = range(num_of_labels)
df_worker_labels.columns = header_list
df_worker_labels = self.expend_labels_df(df_worker_labels) #Now there is a csv file with the actual labels of the source piece and empty columns for the predict labels

# Check if the actual labels are integers and the predict labels are floats, if so convert the actual labels to nerltensorType
if nerltensorType == 'float':
if any(pd.api.types.is_integer_dtype(df_worker_labels[col]) for col in df_worker_labels.columns):
df_worker_labels = df_worker_labels.astype(float)

df_worker_labels = df_worker_labels.dropna()
# print(df_worker_labels)
for batch_id in range(total_batches_per_source):
#build df_worker_labels with the actual labels and the predict labels
index = 0
for batch_id in worker_recived_batches_id:
batch_db = worker_db.get_batch(source_name, str(batch_id))
if batch_db:
# counter = according indexs of array
# cycle = according indexs of panadas (with jump)
cycle = int(batch_db.get_batch_id())
tensor_data = batch_db.get_tensor_data()
# print(f"tensor_data shape: {tensor_data.shape}")
tensor_data = tensor_data.reshape(batch_size, num_of_labels)
#print(df_worker_labels)
#print(tensor_data)
start_index = cycle * batch_size
end_index = (cycle + 1) * batch_size
df_worker_labels.iloc[start_index:end_index, num_of_labels:] = None # Fix an issue of pandas of incompatible dtype
df_worker_labels.iloc[start_index:end_index, num_of_labels:] = tensor_data
# print(df_worker_labels)

if len(self.headers_list) == 1:
if not batch_db: #It's not necessary to check if the batch is missing, because we already know wich batches are recieved
LOG_INFO(f"Batch {batch_id} is missing for worker {worker_name}")
continue
tensor_data = batch_db.get_tensor_data()
tensor_data = tensor_data.reshape(batch_size, num_of_labels).copy() # Make the tensor_data array writable
start_index = index * batch_size
end_index = (index + 1) * batch_size
df_worker_labels.iloc[start_index:end_index, num_of_labels:] = tensor_data
index += 1

if len(self.headers_list) == 1: # One class
class_name = self.headers_list[0]
actual_labels = df_worker_labels.iloc[:, :num_of_labels].values.flatten().tolist()
predict_labels = df_worker_labels.iloc[:, num_of_labels:].values.flatten().tolist()
Expand Down Expand Up @@ -232,8 +220,7 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
confusion_matrix_worker_dict[(worker_name, class_name)] = confusion_matrix
else:
confusion_matrix_worker_dict[(worker_name, class_name)] += confusion_matrix



if plot:
workers = sorted(list({tup[0] for tup in confusion_matrix_worker_dict.keys()}))
classes = sorted(list({tup[1] for tup in confusion_matrix_worker_dict.keys()}))
Expand All @@ -242,7 +229,6 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
for i , worker in enumerate(workers):
for j , pred_class in enumerate(classes):
conf_mat = confusion_matrix_worker_dict[(worker , pred_class)]
# print(f"conf_mat: {conf_mat}")
heatmap = sns.heatmap(data=conf_mat ,ax=ax[i,j], annot=True , fmt="d", cmap='Blues',annot_kws={"size": 8}, cbar_kws={'pad': 0.1})
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize = 8)
Expand All @@ -264,14 +250,40 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
ax[i].set_aspect('equal')
fig.subplots_adjust(wspace=0.4 , hspace=0.4)
plt.show()




return confusion_matrix_source_dict, confusion_matrix_worker_dict

def get_recieved_batches(self):
    """
    Return a dictionary of the batches actually received in this experiment phase.

    Keys are strings of the form "phase:<phase_name>,<source_name>-><worker_name>"
    (built by recieved_batches_key below — the same format that
    get_confusion_matrices uses for lookups); values are lists of integer
    batch ids that the worker received from that source.

    NOTE: the misspellings "recieved"/"recived" are kept for backward
    compatibility with existing callers.
    """
    def recieved_batches_key(phase_name, source_name, worker_name):
        # Key format consumed by callers such as get_confusion_matrices().
        return f"phase:{phase_name},{source_name}->{worker_name}"

    phase_name = self.experiment_phase.get_name()
    recived_batches_dict = {}
    sources_pieces_list = self.experiment_phase.get_sources_pieces()
    workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
    for source_piece_inst in sources_pieces_list:
        source_name = source_piece_inst.get_source_name()
        # Target workers arrive as a comma-separated string, e.g. "w1,w2".
        target_workers_names = source_piece_inst.get_target_workers().split(',')
        for worker_db in workers_model_db_list:
            worker_name = worker_db.get_worker_name()
            if worker_name not in target_workers_names:
                continue  # this worker is not a target of this source piece
            # Key is invariant over batch ids — build it once per (source, worker).
            recieved_batch_key_str = recieved_batches_key(phase_name, source_name, worker_name)
            for batch_id in range(source_piece_inst.get_num_of_batches()):
                batch_db = worker_db.get_batch(source_name, str(batch_id))
                if batch_db:  # batch was received
                    # setdefault: entries exist only for routes with >= 1 received batch,
                    # matching the original behavior.
                    recived_batches_dict.setdefault(recieved_batch_key_str, []).append(batch_id)
    return recived_batches_dict

def get_missed_batches(self):
"""
Returns a list of missed batches in the experiment phase.
Returns a dictionary of missed batches in the experiment phase.
{(source_name, worker_name): [batch_id,...]}
"""
def missed_batches_key(phase_name, source_name, worker_name):
Expand Down

0 comments on commit 3942784

Please sign in to comment.