Skip to content

Commit

Permalink
Add exponential topologies
Browse files Browse the repository at this point in the history
  • Loading branch information
rishi-s8 committed Nov 20, 2024
1 parent 5e960ab commit d50d38c
Show file tree
Hide file tree
Showing 5 changed files with 520 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/algos/fl_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_neighbors(self) -> List[int]:
"""
Returns a list of neighbours for the client.
"""
neighbors = self.topology.sample_neighbours(self.num_collaborators)
neighbors = self.topology.sample_neighbours(self.num_collaborators, mode="pull")
self.stats["neighbors"] = neighbors

return neighbors
Expand Down
11 changes: 10 additions & 1 deletion src/algos/swift.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Module for FedStaticClient and FedStaticServer in Federated Learning.
"""
from typing import Any, Dict, OrderedDict
from typing import Any, Dict, OrderedDict, List
from utils.communication.comm_utils import CommunicationManager
import torch
import time
Expand All @@ -21,6 +21,15 @@ def __init__(
super().__init__(config, comm_utils)
assert self.streaming_aggregation == False, "Streaming aggregation not supported for push-based algorithms for now."

def get_neighbors(self) -> List[int]:
"""
Returns a list of neighbours for the client.
"""
neighbors = self.topology.sample_neighbours(self.num_collaborators, mode="push")
self.stats["neighbors"] = neighbors

return neighbors

def run_protocol(self) -> None:
"""
Runs the federated learning protocol for the client.
Expand Down
37 changes: 34 additions & 3 deletions src/algos/topologies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, config: ConfigType, rank: int) -> None:
self.config = config
self.rank = rank
self.num_users: int = self.config["num_users"] # type: ignore
self.graph: nx.Graph | None = None
self.graph: nx.Graph | List[nx.DiGraph] | None = None
self.neighbor_sample_generator = np.random.default_rng(seed=int(self.config["seed"])*10000 + self.rank ) # type: ignore

@abstractmethod
Expand Down Expand Up @@ -53,15 +53,46 @@ def get_all_neighbours(self) -> List[int]:
if self.graph is None:
raise ValueError("Graph not initialized")
return list(self.graph.neighbors(self.rank)) # type: ignore

def get_in_neighbors(self) -> List[int]:
"""
Returns the list of in neighbours of the current node
"""
return self.get_all_neighbours()

def get_out_neighbors(self) -> List[int]:
"""
Returns the list of out neighbours of the current node
"""
return self.get_all_neighbours()

def sample_neighbours(self, k: int) -> List[int]:
def sample_neighbours(self, k: int, mode = None) -> List[int]:
"""
Returns a random sample of k neighbours of the current node
If the number of neighbours is less than k, return all neighbours
Parameters
----------
k : int
Number of neighbours to sample
mode : str
Mode of sampling - "pull" or "push"
"pull" - Sample neighbours from the incoming edges
"push" - Sample neighbours from the outgoing edges
"""

if self.graph is None:
raise ValueError("Graph not initialized")
neighbours = self.get_all_neighbours()

if mode == "push":
neighbours = self.get_out_neighbors()
elif mode == "pull":
neighbours = self.get_in_neighbors()
else:
neighbours = self.get_all_neighbours()


if len(neighbours) <= k:
return neighbours
return self.neighbor_sample_generator.choice(neighbours, size=k, replace=False).tolist()
Expand Down
Loading

0 comments on commit d50d38c

Please sign in to comment.