diff --git a/src/algos/topologies/collections.py b/src/algos/topologies/collections.py index f324b48..b0ce8e9 100644 --- a/src/algos/topologies/collections.py +++ b/src/algos/topologies/collections.py @@ -4,7 +4,7 @@ from math import ceil, log2 import networkx as nx -from base_exponential import OnePeerExponentialGraph, HyperHyperCube, SimpleBaseGraph, BaseGraph +from algos.topologies.base_exponential import OnePeerExponentialGraph, HyperHyperCube, SimpleBaseGraph, BaseGraph class RingTopology(BaseTopology): @@ -96,25 +96,40 @@ def __init__(self, config: ConfigType, rank: int): super().__init__(config, rank) self.itr = -1 + def _convert_labels_to_int(self) -> None: + """ + Performs two operations: + 1. Convert the labels of the graph to integers - useful for grid like graphs where labels are tuples + 2. Convert the graph to use 1-based indexing - useful for indexing because we reserve 0 for the super node + """ + if self.graph is None: + raise ValueError("Graph not initialized") + self.graph = [nx.convert_node_labels_to_integers(graph, first_label=1) for graph in self.graph] # type: ignore + def generate_graph(self) -> None: raise NotImplementedError def _convert_weight_matrices_to_graph(self, w_list): - return [nx.from_numpy_array(w, create_using=nx.DiGraph) for w in w_list] + g_list = [] + for w in w_list: + G = nx.from_numpy_array(w.numpy(), create_using=nx.DiGraph) + G.remove_edges_from(nx.selfloop_edges(G)) + g_list.append(G) + return g_list def get_in_neighbors(self): """ Returns the list of in neighbours of the current node """ self.itr += 1 - return self.graph[self.itr%len(self.graph)].predecessors(self.rank) + return list(self.graph[self.itr%len(self.graph)].predecessors(self.rank)) def get_out_neighbors(self, i): """ Returns the list of out neighbours of the current node """ self.itr += 1 - return self.graph[self.itr%len(self.graph)].successors(i) + return list(self.graph[self.itr%len(self.graph)].successors(i)) def get_all_neighbours(self) -> List[int]: self.itr += 1 @@ -149,7 +164,7 @@ class HyperHyperCubeTopology(DynamicBaseGraph): def __init__(self, config: ConfigType, rank: int): super().__init__(config, rank) self.seed = config["seed"] - self.max_degree = config["topology"]["max_degree"] + self.max_degree = config["topology"].get("max_degree", 1) def generate_graph(self) -> None: super().generate_graph(HyperHyperCube, self.num_users, self.max_degree, self.seed) @@ -158,7 +173,7 @@ class SimpleBaseGraphTopology(DynamicBaseGraph): def __init__(self, config: ConfigType, rank: int): super().__init__(config, rank) self.seed = config["seed"] - self.max_degree = config["topology"]["max_degree"] + self.max_degree = config["topology"].get("max_degree", 1) self.inner_edges = config["topology"].get("inner_edges", True) def generate_graph(self) -> None: @@ -168,7 +183,7 @@ class BaseGraphTopology(DynamicBaseGraph): def __init__(self, config: ConfigType, rank: int): super().__init__(config, rank) self.seed = config["seed"] - self.max_degree = config["topology"]["max_degree"] + self.max_degree = config["topology"].get("max_degree", 1) self.inner_edges = config["topology"].get("inner_edges", True) def generate_graph(self) -> None: diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 3d85900..366795e 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -203,6 +203,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st # Collaboration setup "algo": "fedstatic", "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore + # "topology": {"name": "base_graph", "max_degree": 2}, # type: ignore "rounds": 3, # Model parameters diff --git a/src/inversefed/__init__.py b/src/inversefed/__init__.py index 0dacb9c..4e31def 100644 --- a/src/inversefed/__init__.py +++ b/src/inversefed/__init__.py @@ -3,7 +3,7 @@ from inversefed import nn from inversefed.nn import construct_model, MetaMonkey -from inversefed.data import construct_dataloaders +# from inversefed.data import construct_dataloaders from inversefed.training import train from inversefed import utils