Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming aggregation for fedstatic-like algos #145

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 99 additions & 5 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def __init__(

self.stats : Dict[str, int | float | List[int]] = {}

self.streaming_aggregation = config.get("streaming_aggregation", False)

def set_constants(self) -> None:
"""Add docstring here"""
self.best_acc = 0.0
Expand Down Expand Up @@ -907,6 +909,39 @@ def aggregate(
self.set_model_weights(agg_wts)
return None

def aggregate_streaming(
    self,
    agg_wts: OrderedDict[str, Tensor],
    model_wts: OrderedDict[str, Tensor],
    coeff: float,
    is_initialized: bool,
    keys_to_ignore: List[str],
) -> None:
    """
    Fold one model's weights, scaled by ``coeff``, into ``agg_wts`` in place.

    The keys visited are those of ``self.model.state_dict()``; every key
    listed in ``keys_to_ignore`` is skipped and left untouched in ``agg_wts``.

    Args:
        agg_wts (OrderedDict[str, Tensor]): Running aggregate, updated in place.
        model_wts (OrderedDict[str, Tensor]): Weights of the model to fold in.
            Its tensors are not modified.
        coeff (float): Multiplier applied to every tensor of ``model_wts``.
        is_initialized (bool): When False, every visited entry of ``agg_wts``
            is (re)assigned. When True, the contribution is added to the
            existing entry; a key that is not yet present in ``agg_wts`` is
            assigned instead of raising ``KeyError``.
        keys_to_ignore (List[str]): State-dict keys excluded from aggregation.

    Returns:
        None

    Raises:
        KeyError: If a visited key is missing from ``model_wts``.
    """
    # Hoisted so the per-key membership test does not rescan the list.
    ignored = set(keys_to_ignore)

    for key in self.model.state_dict():
        if key in ignored:
            continue

        # The multiplication yields a fresh tensor, so the aggregate never
        # aliases the caller's ``model_wts`` storage.
        contribution = coeff * model_wts[key].to(self.device)

        if is_initialized and key in agg_wts:
            # Incrementally update the aggregation state.
            agg_wts[key] += contribution
        else:
            # Start (or restart) the aggregation state for this key.
            agg_wts[key] = contribution

def receive_pushed_and_aggregate(self, remove_multi = True) -> None:
model_updates = self.comm_utils.receive_pushed()
if self.is_working:
Expand All @@ -921,12 +956,71 @@ def receive_pushed_and_aggregate(self, remove_multi = True) -> None:
# Aggregate the representations
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)

def receive_and_aggregate(self, neighbors: List[int]) -> None:
def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None:
    """
    Receive neighbor models one at a time and average them with the local model.

    Each neighbor's update is requested individually and folded into a
    running sum via ``self.aggregate_streaming``. The local model and every
    neighbor use the same coefficient, ``1 / (len(neighbors) + 1)``.
    Neighbors whose update is empty after ``self.strip_empty_models`` are
    skipped, and the sum is finally divided by the total coefficient that
    was actually accumulated before being passed to
    ``self.set_model_weights``.

    Does nothing when ``self.is_working`` is falsy.

    Args:
        neighbors (List[int]): Ids of the nodes to receive models from.

    Returns:
        None

    Raises:
        AssertionError: If the local weights or a received update lack the
            ``"model"`` entry.
    """
    if not self.is_working:
        return

    # Aggregation state, filled in place by ``aggregate_streaming``.
    agg_wts: OrderedDict[str, Tensor] = OrderedDict()
    total_weight = 0.0  # To re-normalize weights after handling dropouts

    # Seed the aggregate with the current node's own model.
    current_model_wts = self.get_model_weights()
    assert "model" in current_model_wts, "Model not found in the current model."
    current_weight = 1.0 / (len(neighbors) + 1)  # Weight for the current node
    self.aggregate_streaming(
        agg_wts,
        current_model_wts["model"],
        coeff=current_weight,
        is_initialized=False,
        keys_to_ignore=self.model_keys_to_ignore,
    )
    total_weight += current_weight

    # Process models from neighbors one at a time.
    for neighbor in neighbors:
        model_update = self.comm_utils.receive(node_ids=[neighbor])
        model_update, _ = self.strip_empty_models(model_update)
        if not model_update:
            # Skip empty models (dropouts)
            continue

        received = model_update[0]
        assert "model" in received, "Model not found in the received message"

        # Every neighbor currently gets the same weight as the local model.
        coeff = current_weight
        self.aggregate_streaming(
            agg_wts,
            received["model"],
            coeff=coeff,
            is_initialized=True,
            keys_to_ignore=self.model_keys_to_ignore,
        )
        total_weight += coeff

    # Re-normalize so the contributions that actually arrived sum to one.
    if total_weight > 0:
        for key in agg_wts:
            agg_wts[key] /= total_weight

    # Update the model with the aggregated weights.
    self.set_model_weights(agg_wts)

def receive_and_aggregate(self, neighbors: List[int]) -> None:
    """
    Receive models from ``neighbors`` and aggregate them into the local model.

    When ``self.streaming_aggregation`` is truthy the work is delegated to
    ``self.receive_and_aggregate_streaming``. Otherwise, and only if
    ``self.is_working`` is truthy, all updates are received in a single call
    and handed to ``self.aggregate``.

    Args:
        neighbors (List[int]): Ids of the nodes to receive models from.

    Returns:
        None
    """
    if self.streaming_aggregation:
        self.receive_and_aggregate_streaming(neighbors)
        return

    if not self.is_working:
        return

    # Receive the model updates from the neighbors
    model_updates = self.comm_utils.receive(node_ids=neighbors)
    # Aggregate the representations
    self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)


def get_collaborator_weights(
Expand Down
1 change: 1 addition & 0 deletions src/algos/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)
assert self.streaming_aggregation == False, "Streaming aggregation not supported for push-based algorithms for now."

def run_protocol(self) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"dropout_dicts": dropout_dicts,
"test_samples_per_user": 200,
"log_memory": True,
# "streaming_aggregation": True, # Make it true for fedstatic
"assign_based_on_host": True,
"hostname_to_device_ids": {
"matlaber1": [2, 3, 4, 5, 6, 7],
Expand Down
Loading