diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 026ed0d..78ee9bd 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -123,6 +123,8 @@ def __init__( self.stats : Dict[str, int | float | List[int]] = {} + self.streaming_aggregation = config.get("streaming_aggregation", False) + def set_constants(self) -> None: """Add docstring here""" self.best_acc = 0.0 @@ -907,6 +909,39 @@ def aggregate( self.set_model_weights(agg_wts) return None + def aggregate_streaming( + self, + agg_wts: OrderedDict[str, Tensor], + model_wts: OrderedDict[str, Tensor], + coeff: float, + is_initialized: bool, + keys_to_ignore: List[str], + ) -> None: + """ + Incrementally aggregates the model weights into the aggregation state. + + Args: + agg_wts (OrderedDict[str, Tensor]): Aggregated weights (to be updated in place). + model_wts (OrderedDict[str, Tensor]): Weights of the current model to aggregate. + coeff (float): Collaboration weight for the current model. + is_initialized (bool): Whether the aggregation state is initialized. + keys_to_ignore (List[str]): Keys to ignore during aggregation. + + Returns: + None + """ + for key in self.model.state_dict().keys(): + if key in keys_to_ignore: + continue + if not is_initialized: + # Initialize the aggregation state + agg_wts[key] = coeff * model_wts[key].to(self.device) + else: + # Incrementally update the aggregation state + agg_wts[key] += coeff * model_wts[key].to(self.device) + + return None + def receive_pushed_and_aggregate(self, remove_multi = True) -> None: model_updates = self.comm_utils.receive_pushed() if self.is_working: @@ -921,12 +956,71 @@ def receive_pushed_and_aggregate(self, remove_multi = True) -> None: # Aggregate the representations self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore) - def receive_and_aggregate(self, neighbors: List[int]) -> None: + def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None: if self.is_working: - # Receive the model updates from the neighbors - model_updates = self.comm_utils.receive(node_ids=neighbors) - # Aggregate the representations - self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore) + # Initialize the aggregation state + agg_wts: OrderedDict[str, Tensor] = OrderedDict() + is_initialized = False + total_weight = 0.0 # To re-normalize weights after handling dropouts + + # Include the current node's model in the aggregation + current_model_wts = self.get_model_weights() + assert "model" in current_model_wts, "Model not found in the current model." + current_model_wts = current_model_wts["model"] + current_weight = 1.0 / (len(neighbors) + 1) # Weight for the current node + self.aggregate_streaming( + agg_wts, + current_model_wts, + coeff=current_weight, + is_initialized=is_initialized, + keys_to_ignore=self.model_keys_to_ignore, + ) + is_initialized = True + total_weight += current_weight + + # Process models from neighbors one at a time + for neighbor in neighbors: + # Receive the model update from the current neighbor + model_update = self.comm_utils.receive(node_ids=[neighbor]) + model_update, _ = self.strip_empty_models(model_update) + if len(model_update) == 0: + # Skip empty models (dropouts) + continue + + model_update = model_update[0] + assert "model" in model_update, "Model not found in the received message" + model_wts = model_update["model"] + + # Get the collaboration weight for the current neighbor + coeff = current_weight # Default weight + + # Perform streaming aggregation for the current model + self.aggregate_streaming( + agg_wts, + model_wts, + coeff=coeff, + is_initialized=is_initialized, + keys_to_ignore=self.model_keys_to_ignore, + ) + total_weight += coeff + + # Re-normalize the aggregated weights if there were dropouts + if total_weight > 0: + for key in agg_wts.keys(): + agg_wts[key] /= total_weight + + # Update the model with the aggregated weights + self.set_model_weights(agg_wts) + + def receive_and_aggregate(self, neighbors: List[int]) -> None: + if self.streaming_aggregation: + self.receive_and_aggregate_streaming(neighbors) + else: + if self.is_working: + # Receive the model updates from the neighbors + model_updates = self.comm_utils.receive(node_ids=neighbors) + # Aggregate the representations + self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore) def get_collaborator_weights( diff --git a/src/algos/swift.py b/src/algos/swift.py index 458e62f..3ea9462 100644 --- a/src/algos/swift.py +++ b/src/algos/swift.py @@ -19,6 +19,7 @@ def __init__( self, config: Dict[str, Any], comm_utils: CommunicationManager ) -> None: super().__init__(config, comm_utils) + assert self.streaming_aggregation == False, "Streaming aggregation not supported for push-based algorithms for now." def run_protocol(self) -> None: """ diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 94c4c65..5ea4d69 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -355,6 +355,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dropout_dicts": dropout_dicts, "test_samples_per_user": 200, "log_memory": True, + # "streaming_aggregation": True, # Make it true for fedstatic "assign_based_on_host": True, "hostname_to_device_ids": { "matlaber1": [2, 3, 4, 5, 6, 7],