Skip to content

Commit

Permalink
Noniid config (#17)
Browse files Browse the repository at this point in the history
* configs work for fed weight non iid

* fixed some inconsistent names

* added fed weight and fed data repr algos

* added sys config for non iid and support distr

* fixed names to allow non iid data

* cleaned

* modified data utils to work with system

* added defkt

* update docs (#29)

* update gitignore (#30)

* fix mkdocs

* Add tasks page

* Update mkdocs yml

* restructuring docs

* add l2c description

* add collabench paper

* revamp docs

* Rename Logging.md to logging.md

* Object Detection Support (#28)

* data loader for pascal dataset

* model for yolo

* added albumentations

* ignoring pascal images

* fixed data loaders

* added path for pascal dataset

* support for training for pascal dset

* settings for object detection and yolo model

* small fix

* adding yolo model to model utils

* exporting the model function

* export model

* changed num samples per user

* changed NUM_CLS to num_cls

* algos updated setting

* configs for yolo and pascal

* enabled support for pretrained yolo:

* support for testing for pascal

* modularize test and train

* modularized

* configs for yolo

* separated the configs for object detect

---------

Co-authored-by: jyuan24 <[email protected]>

* New grpc (#34)

* reduce logging

* revamp communication interface; mpi version works in this commit

* enable strict type checking

* first version of grpc working

* automate configuration

* first version working end to end for fl.py

* add instructions for later

* reduce memory footprint and improve type checking

* keep a small model as default

* Object detection docs (#36)

* object detection docs

* small fix

* revamp docs

* minor fix

* include datasets page (#39)

* add text classification information

* datasets update

* three tables matching three figures in paper

* fixed format

* test

* test

* test

* fixed table 3

* write up fixes

* Update README.md

* Update README.md

* image classification writeup (#43)

* set up development tab

* restructure

* remove view and edit page access

* three tables matching three figures in paper (#47)

* tables, config page

* add feature table (#49)

* config

* getting started dir struct

* add customize.md

* feature comparison table (#53)

* Multiple machines (#55)

* define config data type

* improve grpc execution

* fix missing attribute bug

* add capability to run across multiple machines

* update docs

* update the docs

* configs work for fed weight non iid

* added fed weight and fed data repr algos

* added sys config for non iid and support distr

* added defkt

* updated config to current code

* algo update

* minor fix

* modified algos but comms isn't working

* support for tags, and list of ids

* defkt runs

* fed isolated algo

* modified fed central to work with comm utils

* fed iso and fed central

* modified l2c to fit comms protocol

* add train dset attribute

* fed central finally working

* changed clients to users for all files

* config file updates for l2c and fed central

* changed clients to users

* swarm algo updated to current comm utils protocol

* fl weight updated to new comm utils

* fed weight and swarm works :)

* fl static upgraded to new comms

* fl static works now :)

* took out some unsupported algos

* fix key errors for num_users < 4

* algo config updated

* static.py run

---------

Co-authored-by: jyuan24 <[email protected]>
Co-authored-by: Gautam Jajoo <[email protected]>
Co-authored-by: ishaan <[email protected]>
Co-authored-by: Abhishek Singh <[email protected]>
  • Loading branch information
5 people authored Aug 26, 2024
1 parent dbea41e commit 71d7d19
Show file tree
Hide file tree
Showing 25 changed files with 1,367 additions and 613 deletions.
20 changes: 5 additions & 15 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ The application currently uses MPI and GRPC (experimental) to enable communicati
| 4 | 61.5723 | 50.2906 |
| 6 | 69.4671 | 47.7867 |

---

**AUC Camelyon17 (30 users, 200 rounds)**

| Num Domains | Within Domain | Random |
Expand All @@ -56,21 +54,17 @@ The application currently uses MPI and GRPC (experimental) to enable communicati
| 3 | 179.1761 | 153.0658 |
| 5 | 176.5059 | 139.4547 |

---

**AUC Digit-Five (30 users, 200 rounds)**

| Num Domains | Within Domain | Random |
|-------------|---------------|----------|
| 2 | 71.8536 | 65.6555 |
| 3 | 74.4239 | 72.6996 |
| 5 | 77.3709 | 76.3041 |
| Num Domains | Within Domain | Random |
|-------------|---------------------|---------------|
| 2 | 71.8536 | 65.6555 |
| 3 | 74.4239 | 72.6996 |
| 5 | 77.3709 | 76.3041 |


**Table 4** Test Accuracy and Standard Deviation Over Rounds

**DomainNet (39 users, 3 domains)**

| Rounds | Within Domain | | Random | |
|--------|---------------|-----------|---------------|-----------|
| | Mean | Std | Mean | Std |
Expand All @@ -80,8 +74,6 @@ The application currently uses MPI and GRPC (experimental) to enable communicati
| 400 | 0.4353 | 0.0687 | 0.4355 | 0.0585 |
| 500 | 0.4726 | 0.0502 | 0.4499 | 0.0496 |

---

**Camelyon17 (39 users, 3 domains)**

| Rounds | Within Domain | | Random | |
Expand All @@ -93,8 +85,6 @@ The application currently uses MPI and GRPC (experimental) to enable communicati
| 160 | 0.9361 | 0.0239 | 0.8122 | 0.1346 |
| 200 | 0.9353 | 0.0251 | 0.7762 | 0.1516 |

---

**Digit-Five (39 users, 3 domains)**

| Rounds | Within Domain | | Random | |
Expand Down
2 changes: 2 additions & 0 deletions docs/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ This documentation aims to provide transparency on the logging mechanisms implem
- Scalar Logging: Logs scalar values to TensorBoard for tracking metrics(loss, accuracy) over time.
- Image Logging: Logs images to both a file and TensorBoard for visual analysis.

The tensorboard logs can be viewed by running tensorboard as follows: `tensorboard --logdir=expt_dump/ --host 0.0.0.0`. Assuming `expt_dump` is the folder where the experiment logs are stored.

## Log Sources

| Component/Module | Data Logged | Log Level | Format | Storage Location | Frequency/Trigger |
Expand Down
31 changes: 16 additions & 15 deletions src/algos/L2C.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import OrderedDict
from typing import Any, Dict, List
from utils.communication.comm_utils import CommunicationManager
import torch
import numpy as np
from torch import Tensor, cat, tensor, optim
Expand All @@ -13,8 +14,8 @@


class L2CClient(BaseFedAvgClient):
def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
super().__init__(config, comm_utils)

self.init_collab_weights()
self.sharing_mode = self.config["sharing"]
Expand Down Expand Up @@ -214,22 +215,22 @@ def run_protocol(self):
}

# Wait on server to start the round
self.comm_utils.wait_for_signal(
src=self.server_node, tag=self.tag.ROUND_START
self.comm_utils.receive(
node_ids=self.server_node, tag=self.tag.ROUND_START
)

# Train locally and send the representation to the server
round_stats["train_loss"], round_stats["train_acc"] = self.local_train(
epochs_per_round
)
repr = self.get_representation()
self.comm_utils.send_signal(
self.comm_utils.send(
dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT
)

# Collect the representations from all other nodes from the server
reprs = self.comm_utils.wait_for_signal(
src=self.server_node, tag=self.tag.REPRS_SHARE
reprs = self.comm_utils.receive(
node_ids=self.server_node, tag=self.tag.REPRS_SHARE
)
# In the future this dict might be generated by the server to send
# only requested models
Expand Down Expand Up @@ -264,16 +265,16 @@ def run_protocol(self):

# Lower the number of neighbors
if round == self.config["T_0"]:
self.filter_out_worse_neighbors(self.config["target_clients_after_T_0"])
self.filter_out_worse_neighbors(self.config["target_users_after_T_0"])

self.comm_utils.send_signal(
self.comm_utils.send(
dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS
)


class L2CServer(BaseFedAvgServer):
def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
super().__init__(config, comm_utils)
# self.set_parameters()
self.config = config
self.set_model_parameters(config)
Expand Down Expand Up @@ -301,23 +302,23 @@ def single_round(self):

# Send signal to all clients to start local training
for client_node in self.users:
self.comm_utils.send_signal(
self.comm_utils.send(
dest=client_node, data=None, tag=self.tag.ROUND_START
)
self.log_utils.log_console(
"Server waiting for all clients to finish local training"
)

# Collect representations (from all clients
reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.REPR_ADVERT)
reprs = self.comm_utils.all_gather(self.tag.REPR_ADVERT)
self.log_utils.log_console("Server received all clients models")

# Broadcast the representations to all clients
self.send_representations(reprs)

# Collect round stats from all clients
round_stats = self.comm_utils.wait_for_all_clients(
self.users, self.tag.ROUND_STATS
round_stats = self.comm_utils.all_gather(
self.tag.ROUND_STATS
)
self.log_utils.log_console("Server received all clients stats")

Expand Down
2 changes: 1 addition & 1 deletion src/algos/MetaL2C.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def run_protocol(self):

if round == self.config["T_0"]:
self.filter_out_worse_neighbors(
self.config["target_clients_after_T_0"], collab_weights_dict
self.config["target_users_after_T_0"], collab_weights_dict
)

if self.sharing_mode == "updates":
Expand Down
12 changes: 6 additions & 6 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def is_same_dest(dset):
self.val_dloader = DataLoader(val_dset, batch_size=batch_size, shuffle=True)

self.train_indices = train_indices
self.train_dset = train_dset
# self.dloader = DataLoader(train_dset, batch_size=batch_size*len(self.device_ids), shuffle=True)
self.dloader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)

Expand Down Expand Up @@ -484,9 +485,8 @@ class BaseFedAvgClient(BaseClient):
"""
Abstract class for FedAvg based algorithms
"""

def __init__(self, config, comm_protocol=CommProtocol) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol=CommProtocol) -> None:
super().__init__(config, comm_utils)
self.config = config
self.model_save_path = "{}/saved_models/node_{}.pt".format(
self.config["results_path"], self.node_id
Expand Down Expand Up @@ -659,9 +659,9 @@ class BaseFedAvgServer(BaseServer):
"""
Abstract class for orchestrator
"""

def __init__(self, config, comm_protocol=CommProtocol) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol=CommProtocol) -> None:
super().__init__(config, comm_utils)
self.tag = comm_protocol

def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]):
Expand Down
33 changes: 15 additions & 18 deletions src/algos/def_kt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import copy
import random
from collections import OrderedDict
from typing import List
from typing import Any, Dict, List
from torch import Tensor
from utils.communication.comm_utils import CommunicationManager
import torch.nn as nn

from algos.base_class import BaseClient, BaseServer
Expand All @@ -28,8 +29,8 @@ class DefKTClient(BaseClient):
"""
Client class for DefKT (Deep Mutual Learning with Knowledge Transfer)
"""
def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
super().__init__(config, comm_utils)
self.config = config
self.tag = CommProtocol
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
Expand Down Expand Up @@ -112,15 +113,15 @@ def send_representations(self, representation):
Send the model representations to the clients
"""
for client_node in self.clients:
self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
self.comm_utils.send(client_node, representation, tag=self.tag.UPDATES)
print(f"Node 1 sent average weight to {len(self.clients)} nodes")

def single_round(self, self_repr):
"""
Runs a single training round
"""
print("Node 1 waiting for all clients to finish")
reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.DONE)
reprs = self.comm_utils.all_gather(tag=self.tag.DONE)
reprs.append(self_repr)
print(f"Node 1 received {len(reprs)} clients' weights")
avg_wts = self.aggregate(reprs)
Expand Down Expand Up @@ -151,34 +152,30 @@ def run_protocol(self):
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
for epoch in range(start_epochs, total_epochs):
status = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START)
status = self.comm_utils.receive(0, tag=self.tag.START)
self.assign_own_status(status)
if self.status == "teacher":
self.local_train()
self_repr = self.get_representation()
self.comm_utils.send_signal(
dest=self.pair_id, data=self_repr, tag=self.tag.DONE
)
self.comm_utils.send(dest=self.pair_id, data=self_repr, tag=self.tag.DONE)
print(f"Node {self.node_id} sent repr to student node {self.pair_id}")
elif self.status == "student":
teacher_repr = self.comm_utils.wait_for_signal(
src=self.pair_id, tag=self.tag.DONE
)
teacher_repr = self.comm_utils.receive(self.pair_id, tag=self.tag.DONE)
print(f"Node {self.node_id} received repr from teacher node {self.pair_id}")
self.deep_mutual_train(teacher_repr)
else:
print(f"Node {self.node_id} do nothing")
acc = self.local_test()
print(f"Node {self.node_id} test_acc:{acc:.4f}")
self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)
self.comm_utils.send(0, data=acc, tag=self.tag.FINISH)


class DefKTServer(BaseServer):
"""
Server class for DefKT (Deep Mutual Learning with Knowledge Transfer)
"""
def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
super().__init__(config, comm_utils)
self.config = config
self.set_model_parameters(config)
self.tag = CommProtocol
Expand All @@ -190,7 +187,7 @@ def send_representations(self, representations):
Send the model representations to the clients
"""
for client_node in self.users:
self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES)
self.comm_utils.send(client_node, representations, self.tag.UPDATES)
self.log_utils.log_console(
f"Server sent {len(representations)} representations to node {client_node}"
)
Expand Down Expand Up @@ -232,7 +229,7 @@ def single_round(self):
self.log_utils.log_console(
f"Server sending status from {self.node_id} to {client_node}"
)
self.comm_utils.send_signal(
self.comm_utils.send(
dest=client_node, data=[teachers, students], tag=self.tag.START
)

Expand All @@ -246,5 +243,5 @@ def run_protocol(self):
for epoch in range(start_epochs, total_epochs):
self.log_utils.log_console(f"Starting round {epoch}")
self.single_round()
accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH)
accs = self.comm_utils.all_gather(tag=self.tag.FINISH)
self.log_utils.log_console(f"Round {epoch} done; acc {accs}")
Loading

0 comments on commit 71d7d19

Please sign in to comment.