Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missed batches rr #380

Merged
merged 3 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src_py/apiServer/networkComponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, dc_json: dict):
self.sourcesPolicies = []
self.sourceEpochs = {}
self.routers = []
self.sources_policy_dict = {}


# Initializing maps
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, dc_json: dict):
self.sourcesPolicies.append(source[GetFields.get_policy_field_name()])
self.sourceEpochs[source[GetFields.get_name_field_name()]] = source[GetFields.get_epochs_field_name()]
self.map_name_to_type[source[GetFields.get_name_field_name()]] = TYPE_SOURCE
self.sources_policy_dict[source[GetFields.get_name_field_name()]] = source[GetFields.get_policy_field_name()]

# Getting the names of all the routers:
routersJsons = self.jsonData[GetFields.get_routers_field_name()]
Expand Down
58 changes: 40 additions & 18 deletions src_py/apiServer/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,49 @@ def get_missed_batches(self):
"""
def missed_batches_key(phase_name, source_name, worker_name):
return f"phase:{phase_name},{source_name}->{worker_name}"

if self.phase == PHASE_PREDICTION_STR:
phase_name = self.experiment_phase.get_name()
missed_batches_dict = {}
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
for source_piece_inst in sources_pieces_list:
source_name = source_piece_inst.get_source_name()

phase_name = self.experiment_phase.get_name()
missed_batches_dict = {}
sources_pieces_list = self.experiment_phase.get_sources_pieces()
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
for source_piece_inst in sources_pieces_list:
source_name = source_piece_inst.get_source_name()
source_policy = globe.components.sources_policy_dict[source_name] # 0 -> casting , 1 -> round robin, 2 -> random
target_workers_string = source_piece_inst.get_target_workers()
target_workers_names = target_workers_string.split(',')
if source_policy == '0': # casting policy
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
if worker_name in target_workers_names: # Check if the worker is in the target workers list of this source
for batch_id in range(source_piece_inst.get_num_of_batches()):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # if batch is missing
missed_batch_key_str = missed_batches_key(phase_name, source_name, worker_name)
if missed_batch_key_str not in missed_batches_dict:
missed_batches_dict[missed_batch_key_str] = []
missed_batches_dict[missed_batch_key_str].append(batch_id)
elif source_policy == '1': # round robin policy
number_of_workers = len(target_workers_names)
batches_indexes = [i for i in range(source_piece_inst.get_num_of_batches())]
batch_worker_tuple = [(batch_index, target_workers_names[batch_index % number_of_workers]) for batch_index in batches_indexes] # (batch_index, worker_name_that_should_recive_the_batch)
worker_batches_dict = {worker_name: [] for worker_name in target_workers_names} # Create a dictionary to hold batches id for each worker
for batch_index, worker_name in batch_worker_tuple:
worker_batches_dict[worker_name].append(batch_index)
for worker_db in workers_model_db_list:
worker_name = worker_db.get_worker_name()
total_batches_per_source = worker_db.get_total_batches_per_source(source_name)
for batch_id in range(total_batches_per_source):
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db: # if batch is missing
missed_batch_key_str = missed_batches_key(phase_name, source_name, worker_name)
if missed_batch_key_str not in missed_batches_dict:
missed_batches_dict[missed_batch_key_str] = []
missed_batches_dict[missed_batch_key_str].append(batch_id)
#print(f"missed_batches_dict: {missed_batches_dict}")
if worker_name in target_workers_names: # Check if the worker is in the target workers list of this source
for batch_id in worker_batches_dict[worker_name]:
batch_db = worker_db.get_batch(source_name, str(batch_id))
if not batch_db:
missed_batch_key_str = missed_batches_key(phase_name, source_name, worker_name)
if missed_batch_key_str not in missed_batches_dict:
missed_batches_dict[missed_batch_key_str] = []
missed_batches_dict[missed_batch_key_str].append(batch_id)
elif source_policy == '2': # random policy
LOG_INFO(f"Source {source_name} policy is random, it's not posiblle check for missed batches")
break
return missed_batches_dict

def get_communication_stats_workers(self):
# return dictionary of {worker : {communication_stats}}
communication_stats_workers_dict = OrderedDict()
Expand Down