Skip to content

Commit

Permalink
Merge pull request #380 from leondavi/fix_missed_batches_RR
Browse files Browse the repository at this point in the history
Fix missed batches rr
  • Loading branch information
leondavi authored Jul 28, 2024
2 parents 4f3bd0c + c6b4210 commit 8382305
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
2 changes: 2 additions & 0 deletions src_py/apiServer/networkComponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, dc_json: dict):
self.sourcesPolicies = []
self.sourceEpochs = {}
self.routers = []
self.sources_policy_dict = {}


# Initializing maps
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, dc_json: dict):
self.sourcesPolicies.append(source[GetFields.get_policy_field_name()])
self.sourceEpochs[source[GetFields.get_name_field_name()]] = source[GetFields.get_epochs_field_name()]
self.map_name_to_type[source[GetFields.get_name_field_name()]] = TYPE_SOURCE
self.sources_policy_dict[source[GetFields.get_name_field_name()]] = source[GetFields.get_policy_field_name()]

# Getting the names of all the routers:
routersJsons = self.jsonData[GetFields.get_routers_field_name()]
Expand Down
58 changes: 40 additions & 18 deletions src_py/apiServer/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,49 @@ def get_missed_batches(self):
"""
# Build the dictionary key under which missed batch ids are grouped:
# one key per (phase, source, worker) combination, formatted as
# "phase:<phase_name>,<source_name>-><worker_name>".
def missed_batches_key(phase_name, source_name, worker_name):
return f"phase:{phase_name},{source_name}->{worker_name}"

if self.phase == PHASE_PREDICTION_STR:
phase_name = self.experiment_phase.get_name()
missed_batches_dict = {}
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
for source_piece_inst in sources_pieces_list:
source_name = source_piece_inst.get_source_name()

phase_name = self.experiment_phase.get_name()
missed_batches_dict = {}
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
for source_piece_inst in sources_pieces_list:
source_name = source_piece_inst.get_source_name()
source_policy = globe.components.sources_policy_dict[source_name] # 0 -> casting , 1 -> round robin, 2 -> random
target_workers_string = source_piece_inst.get_target_workers()
target_workers_names = target_workers_string.split(',')
if source_policy == '0': # casting policy
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
if worker_name in target_workers_names: # Check if the worker is in the target workers list of this source
for batch_id in range(source_piece_inst.get_num_of_batches()):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # if batch is missing
missed_batch_key_str = missed_batches_key(phase_name, source_name, worker_name)
if missed_batch_key_str not in missed_batches_dict:
missed_batches_dict[missed_batch_key_str] = []
missed_batches_dict[missed_batch_key_str].append(batch_id)
elif source_policy == '1': # round robin policy
number_of_workers = len(target_workers_names)
batches_indexes = [i for i in range(source_piece_inst.get_num_of_batches())]
batch_worker_tuple = [(batch_index, target_workers_names[batch_index % number_of_workers]) for batch_index in batches_indexes] # (batch_index, worker_name_that_should_recive_the_batch)
worker_batches_dict = {worker_name: [] for worker_name in target_workers_names} # Create a dictionary to hold batches id for each worker
for batch_index, worker_name in batch_worker_tuple:
worker_batches_dict[worker_name].append(batch_index)
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
total_batches_per_source = worker_db.get_total_batches_per_source(source_name)
for batch_id in range(total_batches_per_source):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # if batch is missing
missed_batch_key_str = missed_batches_key(phase_name, source_name, worker_name)
if missed_batch_key_str not in missed_batches_dict:
missed_batches_dict[missed_batch_key_str] = []
missed_batches_dict[missed_batch_key_str].append(batch_id)
#print(f"missed_batches_dict: {missed_batches_dict}")
if worker_name in target_workers_names: # Check if the worker is in the target workers list of this source
for batch_id in worker_batches_dict[worker_name]:
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db:
missed_batch_key_str = missed_batches_key(phase_name, source_name, worker_name)
if missed_batch_key_str not in missed_batches_dict:
missed_batches_dict[missed_batch_key_str] = []
missed_batches_dict[missed_batch_key_str].append(batch_id)
elif source_policy == '2': # random policy
LOG_INFO(f"Source {source_name} policy is random, it's not posiblle check for missed batches")
break
return missed_batches_dict

def get_communication_stats_workers(self):
# return dictionary of {worker : {communication_stats}}
communication_stats_workers_dict = OrderedDict()
Expand Down

0 comments on commit 8382305

Please sign in to comment.