From d50d38c14b82e9337344fe5222b9984030c583e1 Mon Sep 17 00:00:00 2001 From: Rishi Sharma Date: Wed, 20 Nov 2024 21:52:49 +0100 Subject: [PATCH] Add exponential topologies --- src/algos/fl_static.py | 2 +- src/algos/swift.py | 11 +- src/algos/topologies/base.py | 37 ++- src/algos/topologies/base_exponential.py | 382 +++++++++++++++++++++++ src/algos/topologies/collections.py | 94 +++++- 5 files changed, 520 insertions(+), 6 deletions(-) create mode 100644 src/algos/topologies/base_exponential.py diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 9a476093..b6649648 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -28,7 +28,7 @@ def get_neighbors(self) -> List[int]: """ Returns a list of neighbours for the client. """ - neighbors = self.topology.sample_neighbours(self.num_collaborators) + neighbors = self.topology.sample_neighbours(self.num_collaborators, mode="pull") self.stats["neighbors"] = neighbors return neighbors diff --git a/src/algos/swift.py b/src/algos/swift.py index 3ea9462a..9cdbf709 100644 --- a/src/algos/swift.py +++ b/src/algos/swift.py @@ -1,7 +1,7 @@ """ Module for FedStaticClient and FedStaticServer in Federated Learning. """ -from typing import Any, Dict, OrderedDict +from typing import Any, Dict, OrderedDict, List from utils.communication.comm_utils import CommunicationManager import torch import time @@ -21,6 +21,15 @@ def __init__( super().__init__(config, comm_utils) assert self.streaming_aggregation == False, "Streaming aggregation not supported for push-based algorithms for now." + def get_neighbors(self) -> List[int]: + """ + Returns a list of neighbours for the client. + """ + neighbors = self.topology.sample_neighbours(self.num_collaborators, mode="push") + self.stats["neighbors"] = neighbors + + return neighbors + def run_protocol(self) -> None: """ Runs the federated learning protocol for the client. diff --git a/src/algos/topologies/base.py b/src/algos/topologies/base.py index fdc86aef..f80a8b5c 100644 --- a/src/algos/topologies/base.py +++ b/src/algos/topologies/base.py @@ -15,7 +15,7 @@ def __init__(self, config: ConfigType, rank: int) -> None: self.config = config self.rank = rank self.num_users: int = self.config["num_users"] # type: ignore - self.graph: nx.Graph | None = None + self.graph: nx.Graph | List[nx.DiGraph] | None = None self.neighbor_sample_generator = np.random.default_rng(seed=int(self.config["seed"])*10000 + self.rank ) # type: ignore @abstractmethod @@ -53,15 +53,46 @@ def get_all_neighbours(self) -> List[int]: if self.graph is None: raise ValueError("Graph not initialized") return list(self.graph.neighbors(self.rank)) # type: ignore + + def get_in_neighbors(self) -> List[int]: + """ + Returns the list of in neighbours of the current node + """ + return self.get_all_neighbours() + + def get_out_neighbors(self) -> List[int]: + """ + Returns the list of out neighbours of the current node + """ + return self.get_all_neighbours() - def sample_neighbours(self, k: int) -> List[int]: + def sample_neighbours(self, k: int, mode = None) -> List[int]: """ Returns a random sample of k neighbours of the current node If the number of neighbours is less than k, return all neighbours + + Parameters + ---------- + k : int + Number of neighbours to sample + mode : str + Mode of sampling - "pull" or "push" + "pull" - Sample neighbours from the incoming edges + "push" - Sample neighbours from the outgoing edges + """ + if self.graph is None: raise ValueError("Graph not initialized") - neighbours = self.get_all_neighbours() + + if mode == "push": + neighbours = self.get_out_neighbors() + elif mode == "pull": + neighbours = self.get_in_neighbors() + else: + neighbours = self.get_all_neighbours() + + if len(neighbours) <= k: return neighbours return self.neighbor_sample_generator.choice(neighbours, size=k, replace=False).tolist() diff --git a/src/algos/topologies/base_exponential.py b/src/algos/topologies/base_exponential.py new file mode 100644 index 00000000..dd2b240d --- /dev/null +++ b/src/algos/topologies/base_exponential.py @@ -0,0 +1,382 @@ +# Taken from https://github.com/yukiTakezawa/BaseGraph + +# Part of @inproceedings{takezawa2023exponential, +# title={Beyond Exponential Graph: Communication-Efficient Topologies for Decentralized Learning via Finite-time Convergence}, +# author={Yuki Takezawa and Ryoma Sato and Han Bao and Kenta Niwa and Makoto Yamada}, +# year={2023}, +# booktitle={NeurIPS} +#} + +import torch +import math +import copy +import sympy +import numpy as np +import networkx as nx +import math + +class DynamicGraph(): + def __init__(self, w_list): + """ + Parameter + -------- + w_list (list of torch.tensor): + list of mixing matrix + """ + self.w_list = w_list + self.n_nodes = w_list[0].size()[0] + self.length = len(w_list) + self.itr = 0 + + def get_in_neighbors(self, i): + """ + Parameter + ---------- + i (int): + a node index + Return + ---------- + dictionary of (neighbors's index: weight of the edge (i,j)) + """ + w = self.w_list[self.itr%self.length] + + return {idx.item(): w[idx, i].item() for idx in torch.nonzero(w[:,i])} + + def get_out_neighbors(self, i): + """ + Parameter + ---------- + i (int): + a node index + Return + ---------- + dictionary of (neighbors's index: weight of the edge (i,j)) + """ + w = self.w_list[self.itr%self.length] + + return {idx.item(): w[i,idx].item() for idx in torch.nonzero(w[i])} + + + def get_neighbors(self, i): + in_neighbors = self.get_in_neighbors(i) + out_neighbors = self.get_out_neighbors(i) + self.itr += 1 + return in_neighbors, out_neighbors + + + def get_w(self): + w = self.w_list[self.itr%self.length] + self.itr += 1 + return w + + +class OnePeerExponentialGraph(DynamicGraph): + def __init__(self, n_nodes): + w_list = [] + + n_neighbors = int(math.log2(n_nodes-1)) + + for j in range(n_neighbors+1): + + w = torch.zeros((n_nodes, n_nodes)) + for i in range(n_nodes): + w[i,i] = 1/2 + w[i, (i+2**j)%n_nodes] = 1/2 + + w_list.append(w) + + super().__init__(w_list) + + +class HyperHyperCube(DynamicGraph): + def __init__(self, n_nodes, seed=0, max_degree=1): + self.state = np.random.RandomState(seed) + self.max_degree = max_degree + + if n_nodes == 1: + super().__init__([torch.eye(1)]) + else: + if list(sympy.factorint(n_nodes))[-1] > max_degree+1: + print(f"Can not construct {max_degree}-peer graphs") + + node_list = list(range(n_nodes)) + factors_list = self.split_node(node_list, n_nodes) + #print(factors_list) + super().__init__(self.construct(node_list, factors_list, n_nodes)) + + def construct(self, node_list, factors_list, n_nodes): + w_list = [] + for k in range(len(factors_list)): + #print(factors_list) + + w = torch.zeros((n_nodes, n_nodes)) + b = torch.zeros(n_nodes) + + for i_idx in range(len(node_list)): + for nk in range(1, factors_list[k]): + + i = node_list[i_idx] + j = int(i + np.prod(factors_list[:k]) * nk) % n_nodes + + if b[i] < factors_list[k]-1 and b[j] < factors_list[k]-1: + #print("g", i, j, b[i], b[j]) + b[i] += 1 + b[j] += 1 + + w[i, j], w[j, i] = 1/factors_list[k], 1/factors_list[k] + w[i, i], w[j, j] = 1/factors_list[k], 1/factors_list[k] + + w_list.append(w) + + return w_list + + + def split_node(self, node_list, n_nodes): + factors_list = [] + rest = n_nodes + + for factor in reversed(range(2, self.max_degree+2)): + while rest % factor == 0: + factors_list.append(factor) + rest = int(rest / factor) + + if rest == 1: + break + + factors_list.reverse() + return factors_list + + +class SimpleBaseGraph(DynamicGraph): + def __init__(self, n_nodes, max_degree=1, seed=0, inner_edges=True): + self.state = np.random.RandomState(seed) + self.inner_edges = inner_edges + self.max_degree = max_degree + self.n_nodes = n_nodes + + super().__init__(self.construct()) + + def construct(self): + node_list_list, n_nodes_list = self.split_nodes() + node_list_list_list = self.split_nodes2(node_list_list) + L = len(node_list_list) + + if self.n_nodes == 1: + return [torch.eye(1)] + elif max(list(sympy.factorint(self.n_nodes))) <= self.max_degree + 1: + return HyperHyperCube(self.n_nodes, max_degree=self.max_degree).w_list + + # construct k-peer HyperHyperCube + hyperhyper_cubes = [HyperHyperCube(len(node_list_list[i]), max_degree=self.max_degree) for i in range(L)] + hyperhyper_cubes2 = [HyperHyperCube(len(node_list_list_list[i][0]), max_degree=self.max_degree) for i in range(L)] + max_length_of_hyper = len(hyperhyper_cubes[0].w_list) + + b = torch.zeros(L) + true_b = torch.tensor([len(hyperhyper_cube.w_list) for hyperhyper_cube in hyperhyper_cubes2]) + + w_list = [] + m = -1 + while True: + m += 1 + w = torch.zeros((self.n_nodes, self.n_nodes)) + isolated_nodes = None + all_isolated_nodes = None + + for l in reversed(range(L)): + + if m < max_length_of_hyper: + length = len(hyperhyper_cubes[l].w_list) + w += self.extend(hyperhyper_cubes[l].w_list[m % length], node_list_list[l]) + + elif m < max_length_of_hyper + l: + if isolated_nodes is None: + isolated_nodes = copy.deepcopy(node_list_list_list[m - max_length_of_hyper]) + all_isolated_nodes = [node for nodes in isolated_nodes for node in nodes] + + for i in node_list_list[l]: + a_l = len(isolated_nodes) + + for k in range(a_l): + j = isolated_nodes[k].pop(-1) + all_isolated_nodes.remove(j) + w[i, j] = n_nodes_list[m - max_length_of_hyper] / sum(n_nodes_list[m - max_length_of_hyper:]) / a_l + w[j, i] = n_nodes_list[m - max_length_of_hyper] / sum(n_nodes_list[m - max_length_of_hyper:]) / a_l + + w[j, j] = 1 - w[i, j] + w[i, i] = 1 - n_nodes_list[m - max_length_of_hyper] / sum(n_nodes_list[m - max_length_of_hyper:]) + + elif m == max_length_of_hyper + l and l != L-1: + while len(all_isolated_nodes) > 1 and self.inner_edges: + sampled_nodes = all_isolated_nodes[:min(self.max_degree+1,len(all_isolated_nodes))] + + for node_id in sampled_nodes: + all_isolated_nodes.remove(node_id) + + for i in sampled_nodes: + for j in sampled_nodes: + w[i, j] = 1 / len(sampled_nodes) + w[j, i] = 1 / len(sampled_nodes) + w[i, i] = 1 / len(sampled_nodes) + w[j, j] = 1 / len(sampled_nodes) + + else: + if n_nodes_list[l] < self.max_degree+1: + length = len(hyperhyper_cubes[l].w_list) + w += self.extend(hyperhyper_cubes[l].w_list[int(b[l] % length)], node_list_list[l]) + else: + a_l = len(node_list_list_list[l]) + + for k in range(a_l): + length = len(hyperhyper_cubes2[l].w_list) + w += self.extend(hyperhyper_cubes2[l].w_list[int(b[l] % length)], node_list_list_list[l][k]) + + b[l] += 1 + + # add self-loop + for i in range(self.n_nodes): + if w[i, i] == 0: + w[i,i] = 1.0 + w_list.append(w) + + #if (b >= true_b).all(): + # break + if b[0] == len(hyperhyper_cubes2[0].w_list): + break + + return w_list + + def diag(self, X, Y): + new_W = torch.zeros((X.size()[0] + Y.size()[0], X.size()[0] + Y.size()[0])) + new_W[0:X.size()[0], 0:X.size()[0]] = X + new_W[X.size()[0]:, X.size()[0]:] = Y + return new_W + + + def extend(self, w, node_list): + new_w = torch.zeros((self.n_nodes, self.n_nodes)) + for i in range(len(node_list)): + for j in range(len(node_list)): + new_w[node_list[i], node_list[j]] = w[i, j] + return new_w + + def split_nodes(self): + factor = (self.max_degree + 1)**int(math.log(self.n_nodes, self.max_degree+1)) + n_nodes_list = [] + + while sum(n_nodes_list) != self.n_nodes: + + rest = self.n_nodes - sum(n_nodes_list) + + if rest >= factor: + n_nodes_list.append((rest // factor) * factor) + factor = int(factor/(self.max_degree + 1)) + node_list = list(range(self.n_nodes)) + node_list_list = [] + for i in range(len(n_nodes_list)): + node_list_list.append(node_list[sum(n_nodes_list[:i]):sum(n_nodes_list[:i+1])]) + + return node_list_list, n_nodes_list + + + def split_nodes2(self, node_list_list): + """ + len(node_list) can be written as a_l * (max_degree + 1)^{p_l} where al \in \{1, 2, \cdots, k\}. + """ + + node_list_list_list = [] + + for node_list in node_list_list: + n_nodes = len(node_list) + power = math.gcd(n_nodes, (self.max_degree+1) ** int(math.log(n_nodes, self.max_degree+1))) + rest = int(n_nodes / power) + + node_list_list_list.append([]) + for i in range(rest): + node_list_list_list[-1].append(node_list[i*power:(i+1)*power]) + + return node_list_list_list + + +class BaseGraph(DynamicGraph): + def __init__(self, n_nodes, max_degree=1, seed=0, inner_edges=True): + self.state = np.random.RandomState(seed) + self.inner_edges = inner_edges + self.max_degree = max_degree + self.n_nodes = n_nodes + self.seed = seed + + super().__init__(self.construct()) + + def construct(self): + node_list_list1, node_list_list2, n_power, n_rest = self.split_nodes() + + simple_adics = [SimpleBaseGraph(len(node_list_list1[i]), max_degree=self.max_degree) for i in range(n_power)] + hyper_cubes = [HyperHyperCube(len(node_list_list2[i]), max_degree=self.max_degree) for i in range(n_rest)] + + # check which is better + g = SimpleBaseGraph(self.n_nodes, max_degree=self.max_degree, seed=self.seed, inner_edges=self.inner_edges) + if len(g.w_list) < len(simple_adics[0].w_list) + len(hyper_cubes[0].w_list): + return g.w_list + + + w_list = [] + for m in range(len(simple_adics[0].w_list)): + w = torch.zeros((self.n_nodes, self.n_nodes)) + + for l in range(n_power): + w += self.extend(simple_adics[l].w_list[m], node_list_list1[l]) + w_list.append(w) + + for m in range(len(hyper_cubes[0].w_list)): + w = torch.zeros((self.n_nodes, self.n_nodes)) + + for l in range(n_rest): + w += self.extend(hyper_cubes[l].w_list[m], node_list_list2[l]) + w_list.append(w) + + return w_list + + + def diag(self, X, Y): + new_W = torch.zeros((X.size()[0] + Y.size()[0], X.size()[0] + Y.size()[0])) + new_W[0:X.size()[0], 0:X.size()[0]] = X + new_W[X.size()[0]:, X.size()[0]:] = Y + return new_W + + + def extend(self, w, node_list): + new_w = torch.zeros((self.n_nodes, self.n_nodes)) + for i in range(len(node_list)): + for j in range(len(node_list)): + new_w[node_list[i], node_list[j]] = w[i, j] + return new_w + + + def split_nodes(self): + factors = [n**int(math.log(self.n_nodes, n)) for n in range(2, self.max_degree+2)] + factor = np.prod(factors) + n_power = math.gcd(self.n_nodes, factor) + n_rest = int(self.n_nodes / n_power) + + node_list = list(range(self.n_nodes)) + node_list_list1 = [] + for i in range(n_power): + node_list_list1.append(node_list[n_rest*i:n_rest*(i+1)]) + + node_list_list2 = [[] for _ in range(n_rest)] + for i in range(n_power): + for j in range(n_rest): + node_list_list2[j].append(node_list_list1[i][j]) + + return node_list_list1, node_list_list2, n_power, n_rest + + + def get_neighbors(self, i): + in_neighbors = self.get_in_neighbors(i) + out_neighbors = self.get_out_neighbors(i) + self.itr += 1 + + #if self.itr % len(self.w_list) == 0: + # self.w_list = self.shuffle_node_index(self.w_list, self.n_nodes) + + return in_neighbors, out_neighbors \ No newline at end of file diff --git a/src/algos/topologies/collections.py b/src/algos/topologies/collections.py index 632f42b4..f324b480 100644 --- a/src/algos/topologies/collections.py +++ b/src/algos/topologies/collections.py @@ -1,8 +1,10 @@ +from typing import List from algos.topologies.base import BaseTopology from utils.types import ConfigType -from math import ceil +from math import ceil, log2 import networkx as nx +from base_exponential import OnePeerExponentialGraph, HyperHyperCube, SimpleBaseGraph, BaseGraph class RingTopology(BaseTopology): @@ -89,6 +91,88 @@ def __init__(self, config: ConfigType, rank: int): def generate_graph(self) -> None: self.graph = nx.watts_strogatz_graph(self.num_users, self.k, self.p, self.seed) # type: ignore +class DynamicGraph(BaseTopology): + def __init__(self, config: ConfigType, rank: int): + super().__init__(config, rank) + self.itr = -1 + + def generate_graph(self) -> None: + raise NotImplementedError + + def _convert_weight_matrices_to_graph(self, w_list): + return [nx.from_numpy_array(w, create_using=nx.DiGraph) for w in w_list] + + def get_in_neighbors(self): + """ + Returns the list of in neighbours of the current node + """ + self.itr += 1 + return self.graph[self.itr%len(self.graph)].predecessors(self.rank) + + def get_out_neighbors(self, i): + """ + Returns the list of out neighbours of the current node + """ + self.itr += 1 + return self.graph[self.itr%len(self.graph)].successors(i) + + def get_all_neighbours(self) -> List[int]: + self.itr += 1 + return self.get_in_neighbors(self.rank) + +class DynamicBaseGraph(DynamicGraph): + def __init__(self, config: ConfigType, rank: int): + super().__init__(config, rank) + + def generate_graph(self, class_name, *args, **kwargs) -> None: + w_list = class_name(*args, **kwargs).w_list + self.graph = self._convert_weight_matrices_to_graph(w_list) + + +class OnePeerExponentialTopology(DynamicBaseGraph): + def __init__(self, config: ConfigType, rank: int): + super().__init__(config, rank) + + def generate_graph(self) -> None: + super().generate_graph(OnePeerExponentialGraph, self.num_users) + + # def generate_graph(self) -> None: + # self.graph = [nx.DiGraph() for _ in range(self.num_users)] + + # num_neighbors = int(log2(self.num_users-1)) + + # for j in range(num_neighbors+1): + # for i in range(self.num_users): + # self.graph[j].add_edge(self.rank, (i+2**j)%self.num_users) + +class HyperHyperCubeTopology(DynamicBaseGraph): + def __init__(self, config: ConfigType, rank: int): + super().__init__(config, rank) + self.seed = config["seed"] + self.max_degree = config["topology"]["max_degree"] + + def generate_graph(self) -> None: + super().generate_graph(HyperHyperCube, self.num_users, self.max_degree, self.seed) + +class SimpleBaseGraphTopology(DynamicBaseGraph): + def __init__(self, config: ConfigType, rank: int): + super().__init__(config, rank) + self.seed = config["seed"] + self.max_degree = config["topology"]["max_degree"] + self.inner_edges = config["topology"].get("inner_edges", True) + + def generate_graph(self) -> None: + super().generate_graph(SimpleBaseGraph, self.num_users, self.max_degree, self.seed, self.inner_edges) + +class BaseGraphTopology(DynamicBaseGraph): + def __init__(self, config: ConfigType, rank: int): + super().__init__(config, rank) + self.seed = config["seed"] + self.max_degree = config["topology"]["max_degree"] + self.inner_edges = config["topology"].get("inner_edges", True) + + def generate_graph(self) -> None: + super().generate_graph(BaseGraph, self.num_users, self.max_degree, self.seed, self.inner_edges) def select_topology(config: ConfigType, rank: int) -> BaseTopology: """ @@ -113,4 +197,12 @@ def select_topology(config: ConfigType, rank: int) -> BaseTopology: return WattsStrogatzTopology(config, rank) if topology_name == "tree": return TreeTopology(config, rank) + if topology_name == "one_peer_exponential": + return OnePeerExponentialTopology(config, rank) + if topology_name == "hyper_hypercube": + return HyperHyperCubeTopology(config, rank) + if topology_name == "simple_base_graph": + return SimpleBaseGraphTopology(config, rank) + if topology_name == "base_graph": + return BaseGraphTopology(config, rank) raise ValueError(f"Topology {topology_name} not implemented") \ No newline at end of file