Skip to content

Commit

Permalink
Add streaming aggregation for fedstatic-like algos (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
rishi-s8 authored Nov 15, 2024
1 parent 1ec39b6 commit 8e3b4ac
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 5 deletions.
104 changes: 99 additions & 5 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def __init__(

self.stats : Dict[str, int | float | List[int]] = {}

self.streaming_aggregation = config.get("streaming_aggregation", False)

def set_constants(self) -> None:
"""Add docstring here"""
self.best_acc = 0.0
Expand Down Expand Up @@ -907,6 +909,39 @@ def aggregate(
self.set_model_weights(agg_wts)
return None

def aggregate_streaming(
self,
agg_wts: OrderedDict[str, Tensor],
model_wts: OrderedDict[str, Tensor],
coeff: float,
is_initialized: bool,
keys_to_ignore: List[str],
) -> None:
"""
Incrementally aggregates the model weights into the aggregation state.
Args:
agg_wts (OrderedDict[str, Tensor]): Aggregated weights (to be updated in place).
model_wts (OrderedDict[str, Tensor]): Weights of the current model to aggregate.
coeff (float): Collaboration weight for the current model.
is_initialized (bool): Whether the aggregation state is initialized.
keys_to_ignore (List[str]): Keys to ignore during aggregation.
Returns:
None
"""
for key in self.model.state_dict().keys():
if key in keys_to_ignore:
continue
if not is_initialized:
# Initialize the aggregation state
agg_wts[key] = coeff * model_wts[key].to(self.device)
else:
# Incrementally update the aggregation state
agg_wts[key] += coeff * model_wts[key].to(self.device)

return None

def receive_pushed_and_aggregate(self, remove_multi = True) -> None:
model_updates = self.comm_utils.receive_pushed()
if self.is_working:
Expand All @@ -921,12 +956,71 @@ def receive_pushed_and_aggregate(self, remove_multi = True) -> None:
# Aggregate the representations
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)

def receive_and_aggregate(self, neighbors: List[int]) -> None:
def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None:
if self.is_working:
# Receive the model updates from the neighbors
model_updates = self.comm_utils.receive(node_ids=neighbors)
# Aggregate the representations
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
# Initialize the aggregation state
agg_wts: OrderedDict[str, Tensor] = OrderedDict()
is_initialized = False
total_weight = 0.0 # To re-normalize weights after handling dropouts

# Include the current node's model in the aggregation
current_model_wts = self.get_model_weights()
assert "model" in current_model_wts, "Model not found in the current model."
current_model_wts = current_model_wts["model"]
current_weight = 1.0 / (len(neighbors) + 1) # Weight for the current node
self.aggregate_streaming(
agg_wts,
current_model_wts,
coeff=current_weight,
is_initialized=is_initialized,
keys_to_ignore=self.model_keys_to_ignore,
)
is_initialized = True
total_weight += current_weight

# Process models from neighbors one at a time
for neighbor in neighbors:
# Receive the model update from the current neighbor
model_update = self.comm_utils.receive(node_ids=[neighbor])
model_update, _ = self.strip_empty_models(model_update)
if len(model_update) == 0:
# Skip empty models (dropouts)
continue

model_update = model_update[0]
assert "model" in model_update, "Model not found in the received message"
model_wts = model_update["model"]

# Get the collaboration weight for the current neighbor
coeff = current_weight # Default weight

# Perform streaming aggregation for the current model
self.aggregate_streaming(
agg_wts,
model_wts,
coeff=coeff,
is_initialized=is_initialized,
keys_to_ignore=self.model_keys_to_ignore,
)
total_weight += coeff

# Re-normalize the aggregated weights if there were dropouts
if total_weight > 0:
for key in agg_wts.keys():
agg_wts[key] /= total_weight

# Update the model with the aggregated weights
self.set_model_weights(agg_wts)

def receive_and_aggregate(self, neighbors: List[int]) -> None:
if self.streaming_aggregation:
self.receive_and_aggregate_streaming(neighbors)
else:
if self.is_working:
# Receive the model updates from the neighbors
model_updates = self.comm_utils.receive(node_ids=neighbors)
# Aggregate the representations
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)


def get_collaborator_weights(
Expand Down
1 change: 1 addition & 0 deletions src/algos/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)
assert self.streaming_aggregation == False, "Streaming aggregation not supported for push-based algorithms for now."

def run_protocol(self) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"dropout_dicts": dropout_dicts,
"test_samples_per_user": 200,
"log_memory": True,
# "streaming_aggregation": True, # Make it true for fedstatic
"assign_based_on_host": True,
"hostname_to_device_ids": {
"matlaber1": [2, 3, 4, 5, 6, 7],
Expand Down

0 comments on commit 8e3b4ac

Please sign in to comment.