[new_confusion_matrix] Build a new func that supports any policy #384

Merged: 1 commit, Aug 7, 2024
156 changes: 84 additions & 72 deletions src_py/apiServer/stats.py
@@ -28,10 +28,10 @@ def __init__(self, experiment_phase: ExperimentPhase):
self.experiment_flow_type = self.experiment_phase.get_experiment_flow_type()
if (self.phase == PHASE_PREDICTION_STR):
for source_piece_inst in self.experiment_phase.get_sources_pieces():
csv_dataset = source_piece_inst.get_csv_dataset_parent()
source_piece_csv_labels_file = csv_dataset.genrate_source_piece_ds_csv_file_labels(source_piece_inst, self.phase)
csv_dataset_inst = source_piece_inst.get_csv_dataset_parent() # get the csv dataset instance (csv_dataset_db.py)
source_piece_csv_labels_file = csv_dataset_inst.genrate_source_piece_ds_csv_file_labels(source_piece_inst, self.phase) # returns the path of the CSV file that holds the source piece's label data
source_piece_inst.set_pointer_to_sourcePiece_CsvDataSet_labels(source_piece_csv_labels_file)
self.headers_list = csv_dataset.get_headers_row()
self.headers_list = csv_dataset_inst.get_headers_row()

Path(f'{EXPERIMENT_RESULTS_PATH}/{self.experiment_phase.get_experiment_flow_name()}').mkdir(parents=True, exist_ok=True)
Path(f'{EXPERIMENT_RESULTS_PATH}/{self.experiment_phase.get_phase_type()}/{self.experiment_phase.get_experiment_flow_name()}').mkdir(parents=True, exist_ok=True)
@@ -49,13 +49,13 @@ def get_loss_foh_missing(self , plot : bool = False , saveToFile : bool = False)
pass

#TODO Implement this function
def get_loss_by_source(self , plot : bool = False , saveToFile : bool = False): # Todo change i
def get_loss_by_source(self , plot : bool = False , saveToFile : bool = False):
"""
Returns a dictionary of {source : DataFrame[BatchID,'w1','w2'...'wn']} for each source in the experiment.
"""
pass
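Since get_loss_by_source is still a stub, here is a minimal, hypothetical illustration of the return shape its docstring promises (source and worker names are invented):

import pandas as pd

# hypothetical {source : DataFrame[BatchID,'w1','w2'...'wn']} result
loss_by_source = {
    "s1": pd.DataFrame({"BatchID": [0, 1, 2], "w1": [0.9, 0.6, 0.4], "w2": [1.1, 0.7, 0.5]}),
}
print(loss_by_source["s1"])  # one loss column per worker, indexed by batch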

def get_loss_ts(self , plot : bool = False , saveToFile : bool = False): # Todo change it
def get_loss_ts(self , plot : bool = False , saveToFile : bool = False):
"""
Returns a dictionary of {worker : loss list} for each worker in the experiment.
use plot=True to plot the loss function.
@@ -83,7 +83,6 @@ def get_loss_ts(self , plot : bool = False , saveToFile : bool = False): # Todo

df = pd.DataFrame(loss_dict)
self.loss_ts_pd = df
#print(df)

if plot:
sns.lineplot(data=df)
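For readers skimming the hunk above, this is the pattern get_loss_ts follows, reduced to a standalone sketch with made-up loss values: a wide DataFrame (one column per worker) that seaborn turns into one loss curve per worker.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

loss_dict = {"w1": [0.9, 0.55, 0.4], "w2": [1.1, 0.6, 0.35]}  # {worker : loss list}
df = pd.DataFrame(loss_dict)  # rows are batch indices, columns are workers
sns.lineplot(data=df)         # one curve per column
plt.xlabel('Batch Number')
plt.ylabel('Loss')
plt.show()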
@@ -117,25 +116,6 @@ def get_min_loss(self , plot : bool = False , saveToFile : bool = False): # Todo
plt.title('Training Min Loss')
return min_loss_dict


# if plot:
# plt.figure(figsize = (30,15), dpi = 150)
# plt.rcParams.update({'font.size': 22})
# for worker_name, loss_list in loss_dict.items():
# plt.plot(loss_list, label=worker_name)
# plt.legend(list(loss_dict.keys()))
# plt.xlabel('Batch Number' , fontsize=30)
# plt.ylabel('Loss' , fontsize=30)
# plt.yscale('log')
# plt.xlim(left=0)
# plt.ylim(bottom=0)
# plt.title('Training Loss Function')
# plt.grid(visible=True, which='major', linestyle='-')
# plt.minorticks_on()
# plt.grid(visible=True, which='minor', linestyle='-', alpha=0.7)
# plt.show()
# plt.savefig(f'{EXPERIMENT_RESULTS_PATH}/{self.experiment.name}/Training/Loss_graph.png')

def expend_labels_df(self, df):
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
temp_list = list(range(df.shape[1]))
@@ -144,65 +124,73 @@ def expend_labels_df(self, df):
df = df.reindex(columns = [*df.columns.tolist(), *temp_list], fill_value = 0)
assert df.shape[1] == 2 * num_of_labels, "Error in expend_labels_df function"
return df


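expend_labels_df is split across the hunk boundary above, but the reindex idiom it uses is visible: append one zero-filled column per label so the frame carries the actual labels on the left and has room for the predicted labels on the right. A runnable sketch (the shifted names of the appended columns are an assumption, since those lines are elided):

import pandas as pd

# one-hot actual labels for 4 samples over 3 classes (toy data)
df = pd.DataFrame([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], columns=[0, 1, 2])
num_of_labels = df.shape[1]
pred_cols = [num_of_labels + i for i in range(num_of_labels)]  # assumed naming: 3, 4, 5
df = df.reindex(columns=[*df.columns.tolist(), *pred_cols], fill_value=0)
assert df.shape[1] == 2 * num_of_labels
# columns 0..2 hold the actual labels; columns 3..5 are placeholders for predictions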
def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False , saveToFile : bool = False):

def build_worker_label_df(original_df, batch_ids, batch_size):
rows_list = []

for batch_id in batch_ids:
# Calculate the start and end indices for the rows to be copied
start_idx = batch_id * batch_size
end_idx = (batch_id + 1) * batch_size

# Extract the rows and append to the list
batch_rows = original_df.iloc[start_idx:end_idx]
rows_list.append(batch_rows)

# Concatenate all the extracted rows into a new DataFrame
df_worker_labels = pd.concat(rows_list, ignore_index=True)
return df_worker_labels

assert self.experiment_flow_type == "classification", "This function is only available for classification experiments"
assert self.phase == PHASE_PREDICTION_STR, "This function is only available for predict phase"
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
confusion_matrix_source_dict = {}
confusion_matrix_worker_dict = {}
recived_batches_dict = self.get_recieved_batches()
for source_piece_inst in sources_pieces_list:
nerltensorType = source_piece_inst.get_nerltensor_type()
source_name = source_piece_inst.get_source_name()
sourcePiece_csv_labels_path = source_piece_inst.get_pointer_to_sourcePiece_CsvDataSet_labels()
df_actual_labels = pd.read_csv(sourcePiece_csv_labels_path)
num_of_labels = df_actual_labels.shape[1]
header_list = range(num_of_labels)
df_actual_labels.columns = header_list
df_actual_labels = self.expend_labels_df(df_actual_labels)
#print(df_actual_labels)
source_name = source_piece_inst.get_source_name()

# build confusion matrix for each worker
target_workers = source_piece_inst.get_target_workers()
worker_missed_batches = {}
# build confusion matrix for each worker
target_workers_string = source_piece_inst.get_target_workers()
target_workers_names = target_workers_string.split(',')
batch_size = source_piece_inst.get_batch_size()
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
if worker_name not in target_workers:
if worker_name not in target_workers_names:
continue
df_worker_labels = df_actual_labels.copy()
total_batches_per_source = worker_db.get_total_batches_per_source(source_name)
for batch_id in range(total_batches_per_source):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # if batch is missing
if not self.missed_batches_warning_msg:
LOG_WARNING(f"missed batches")
self.missed_batches_warning_msg = True
starting_offset = source_piece_inst.get_starting_offset()
df_worker_labels.iloc[batch_id * batch_size: (batch_id + 1) * batch_size, num_of_labels:] = None # set the actual label to None for the predict labels in the df
worker_missed_batches[(worker_name, source_name, str(batch_id))] = (starting_offset + batch_id * batch_size, batch_size) # save the missing batch
worker_recived_batches_id = recived_batches_dict.get(f"phase:{self.experiment_phase.get_name()},{source_name}->{worker_name}") # get the list of received batch ids for this worker
df_worker_labels = build_worker_label_df(df_actual_labels, worker_recived_batches_id, batch_size)
header_list = range(num_of_labels)
df_worker_labels.columns = header_list
df_worker_labels = self.expend_labels_df(df_worker_labels) # now the df holds the actual labels of the source piece plus empty columns for the predicted labels

# If the actual labels are integers while the nerltensor type is float, cast them to float so the dtypes match the predictions
if nerltensorType == 'float':
if any(pd.api.types.is_integer_dtype(df_worker_labels[col]) for col in df_worker_labels.columns):
df_worker_labels = df_worker_labels.astype(float)

df_worker_labels = df_worker_labels.dropna()
# print(df_worker_labels)
for batch_id in range(total_batches_per_source):
# build df_worker_labels with the actual labels and the predicted labels
index = 0
for batch_id in worker_recived_batches_id:
batch_db = worker_db.get_batch(source_name, str(batch_id))
if batch_db:
# counter = indices into the array
# cycle = indices into the pandas df (with jumps)
cycle = int(batch_db.get_batch_id())
tensor_data = batch_db.get_tensor_data()
# print(f"tensor_data shape: {tensor_data.shape}")
tensor_data = tensor_data.reshape(batch_size, num_of_labels)
#print(df_worker_labels)
#print(tensor_data)
start_index = cycle * batch_size
end_index = (cycle + 1) * batch_size
df_worker_labels.iloc[start_index:end_index, num_of_labels:] = None # Fix an issue of pandas of incompatible dtype
df_worker_labels.iloc[start_index:end_index, num_of_labels:] = tensor_data
# print(df_worker_labels)

if len(self.headers_list) == 1:
if not batch_db: # not strictly necessary to check for a missing batch here, since we already know which batches were received
LOG_INFO(f"Batch {batch_id} is missing for worker {worker_name}")
continue
tensor_data = batch_db.get_tensor_data()
tensor_data = tensor_data.reshape(batch_size, num_of_labels).copy() # Make the tensor_data array writable
start_index = index * batch_size
end_index = (index + 1) * batch_size
df_worker_labels.iloc[start_index:end_index, num_of_labels:] = tensor_data
index += 1

if len(self.headers_list) == 1: # One class
class_name = self.headers_list[0]
actual_labels = df_worker_labels.iloc[:, :num_of_labels].values.flatten().tolist()
predict_labels = df_worker_labels.iloc[:, num_of_labels:].values.flatten().tolist()
@@ -232,8 +220,7 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
confusion_matrix_worker_dict[(worker_name, class_name)] = confusion_matrix
else:
confusion_matrix_worker_dict[(worker_name, class_name)] += confusion_matrix

if plot:
workers = sorted(list({tup[0] for tup in confusion_matrix_worker_dict.keys()}))
classes = sorted(list({tup[1] for tup in confusion_matrix_worker_dict.keys()}))
@@ -242,7 +229,6 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
for i , worker in enumerate(workers):
for j , pred_class in enumerate(classes):
conf_mat = confusion_matrix_worker_dict[(worker , pred_class)]
# print(f"conf_mat: {conf_mat}")
heatmap = sns.heatmap(data=conf_mat ,ax=ax[i,j], annot=True , fmt="d", cmap='Blues',annot_kws={"size": 8}, cbar_kws={'pad': 0.1})
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize = 8)
@@ -264,14 +250,40 @@ def get_confusion_matrices(self , normalize : bool = False ,plot : bool = False
ax[i].set_aspect('equal')
fig.subplots_adjust(wspace=0.4 , hspace=0.4)
plt.show()

return confusion_matrix_source_dict, confusion_matrix_worker_dict

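The new build_worker_label_df helper copies one row window per batch the worker actually received, which is what lets the confusion matrices cope with any distribution policy. A self-contained sketch of the same arithmetic with toy data:

import pandas as pd

def build_worker_label_df(original_df, batch_ids, batch_size):
    # each batch_id owns rows [batch_id * batch_size, (batch_id + 1) * batch_size)
    rows_list = [original_df.iloc[b * batch_size:(b + 1) * batch_size] for b in batch_ids]
    return pd.concat(rows_list, ignore_index=True)

labels = pd.DataFrame({"label": range(10)})        # 10 samples, batch_size=2 -> 5 batches
received = [0, 2, 4]                               # batches 1 and 3 never arrived
print(build_worker_label_df(labels, received, 2))  # rows 0-1, 4-5, 8-9, reindexed 0..5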
def get_recieved_batches(self):
"""
Returns a dictionary of received batches in the experiment phase.
recived_batches_dict = {"phase:<phase_name>,<source_name>-><worker_name>": [batch_id, ...]}
"""
def recieved_batches_key(phase_name, source_name, worker_name):
return f"phase:{phase_name},{source_name}->{worker_name}"

phase_name = self.experiment_phase.get_name()
recived_batches_dict = {}
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
for source_piece_inst in sources_pieces_list:
source_name = source_piece_inst.get_source_name()
target_workers_string = source_piece_inst.get_target_workers()
target_workers_names = target_workers_string.split(',')
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
if worker_name in target_workers_names: # Check if the worker is in the target workers list of this source
for batch_id in range(source_piece_inst.get_num_of_batches()):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if batch_db: # if the batch was received
recieved_batch_key_str = recieved_batches_key(phase_name, source_name, worker_name)
if recieved_batch_key_str not in recived_batches_dict:
recived_batches_dict[recieved_batch_key_str] = []
recived_batches_dict[recieved_batch_key_str].append(batch_id)
return recived_batches_dict

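Note that the dictionary is keyed by a formatted string rather than a (source, worker) tuple, so get_confusion_matrices must rebuild exactly the same string to look a worker's batches up. A toy round trip (phase, source, and worker names are invented):

def recieved_batches_key(phase_name, source_name, worker_name):
    return f"phase:{phase_name},{source_name}->{worker_name}"

recived_batches_dict = {}
recived_batches_dict.setdefault(recieved_batches_key("prediction", "s1", "w1"), []).append(0)
recived_batches_dict.setdefault(recieved_batches_key("prediction", "s1", "w1"), []).append(2)
print(recived_batches_dict.get("phase:prediction,s1->w1"))  # [0, 2]

dict.setdefault, as used here, would also collapse the explicit "if key not in dict" check in the loop above into a single call; the two forms are equivalent.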
def get_missed_batches(self):
"""
Returns a list of missed batches in the experiment phase.
Returns a dictionary of missed batches in the experiment phase.
{"phase:<phase_name>,<source_name>-><worker_name>": [batch_id, ...]}
"""
def missed_batches_key(phase_name, source_name, worker_name):