Skip to content

Commit

Permalink
Merge branch 'main' into intro-abc-conn
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jan 24, 2025
2 parents 42bffc8 + 5f9951f commit 98f84c8
Show file tree
Hide file tree
Showing 145 changed files with 9,879 additions and 5,793 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ jobs:
- directory: e2e-pytorch-lightning
e2e: e2e_pytorch_lightning
dataset: |
from torchvision.datasets import MNIST
MNIST('./data', download=True)

- directory: e2e-scikit-learn
e2e: e2e_scikit_learn
Expand Down
4 changes: 2 additions & 2 deletions baselines/dasha/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ classifiers = [
python = ">=3.10.0, <3.11.0" # don't change this
flwr = { extras = ["simulation"], version = "1.5.0" }
hydra-core = "1.3.2" # don't change this
scikit-learn = "1.3.0"
scikit-learn = "1.5.0"
matplotlib = "3.7.1"
# Installing Torch. TODO: Fix and install the relevant version of Torch based on the current system.
torch = [{ url = "https://download.pytorch.org/whl/cu118/torch-2.0.0%2Bcu118-cp310-cp310-linux_x86_64.whl", markers="sys_platform == 'linux'"},
Expand All @@ -59,7 +59,7 @@ pytest = "==6.2.4"
pytest-watch = "==4.2.0"
types-requests = "==2.27.7"
py-spy = "==0.3.14"
ruff = "==0.0.272"
ruff = "==0.4.5"
virtualenv = "==20.21.0"

[tool.isort]
Expand Down
6 changes: 3 additions & 3 deletions baselines/depthfl/depthfl/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def gen_client_fn( # pylint: disable=too-many-arguments
learning_rate: float,
learning_rate_decay: float,
models: List[DictConfig],
) -> Callable[[str], FlowerClient]:
) -> Callable[[str], fl.client.Client]:
"""Generate the client function that creates the Flower Clients.
Parameters
Expand All @@ -150,7 +150,7 @@ def gen_client_fn( # pylint: disable=too-many-arguments
client function that creates Flower Clients
"""

def client_fn(cid: str) -> FlowerClient:
def client_fn(cid: str) -> fl.client.Client:
"""Create a Flower client representing a single organization."""
# Load model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand All @@ -176,6 +176,6 @@ def client_fn(cid: str) -> FlowerClient:
learning_rate_decay,
prev_grads,
int(cid),
)
).to_client()

return client_fn
1 change: 1 addition & 0 deletions baselines/depthfl/depthfl/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def fit_round(
client_instructions=client_instructions,
max_workers=self.max_workers,
timeout=timeout,
group_id=server_round,
)
log(
DEBUG,
Expand Down
18 changes: 8 additions & 10 deletions baselines/depthfl/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand All @@ -38,7 +35,7 @@ classifiers = [

[tool.poetry.dependencies]
python = ">=3.10.0, <3.11.0"
flwr = { extras = ["simulation"], version = "1.5.0" }
flwr = { extras = ["simulation"], version = "1.9.0" }
hydra-core = "1.3.2" # don't change this
matplotlib = "3.7.1"
torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl" }
Expand All @@ -53,9 +50,8 @@ pylint = "==2.8.2"
flake8 = "==3.9.2"
pytest = "==6.2.4"
pytest-watch = "==4.2.0"
ruff = "==0.0.272"
ruff = "==0.4.5"
types-requests = "==2.27.7"
virtualenv = "==20.21.0"

[tool.isort]
line_length = 88
Expand Down Expand Up @@ -108,11 +104,8 @@ wrap-summaries = 88
wrap-descriptions = 88

[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 88
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
ignore = ["B024", "B027"]
exclude = [
".bzr",
".direnv",
Expand All @@ -137,5 +130,10 @@ exclude = [
"proto",
]

[tool.ruff.lint]
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
ignore = ["B024", "B027"]

[tool.ruff.pydocstyle]
convention = "numpy"
2 changes: 1 addition & 1 deletion baselines/fedmeta/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ python = ">=3.10.0, <3.11.0"
flwr = { extras = ["simulation"], version = "1.5.0" }
hydra-core = "1.3.2" # don't change this
matplotlib = "3.7.1"
scikit-learn = "1.3.1"
scikit-learn = "1.5.0"
torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl" }
torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl" }

Expand Down
6 changes: 3 additions & 3 deletions baselines/fednova/fednova/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def gen_clients_fednova( # pylint: disable=too-many-arguments
data_sizes: List,
model: DictConfig,
exp_config: DictConfig,
) -> Callable[[str], FedNovaClient]:
) -> Callable[[str], fl.client.Client]:
"""Return a generator function to create a FedNova client."""

def client_fn(cid: str) -> FedNovaClient:
def client_fn(cid: str) -> fl.client.Client:
"""Create a Flower client representing a single organization."""
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -144,6 +144,6 @@ def client_fn(cid: str) -> FedNovaClient:
num_epochs,
client_dataset_ratio,
exp_config,
)
).to_client()

return client_fn
20 changes: 9 additions & 11 deletions baselines/fednova/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand All @@ -39,11 +36,11 @@ classifiers = [
[tool.poetry.dependencies]
# tested with python == 3.10.11
python = ">=3.10.0, <3.11.0"
flwr = { extras = ["simulation"], version = "1.5.0" }
flwr = { extras = ["simulation"], version = "1.9.0" }
hydra-core = "1.3.2" # don't change this
torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl" }
torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl" }
numpy = "1.21.6"
numpy = "1.22.0"
matplotlib = "3.5.3"
pandas = "1.3.5"

Expand All @@ -56,9 +53,8 @@ pylint = "==2.8.2"
flake8 = "==3.9.2"
pytest = "==6.2.4"
pytest-watch = "==4.2.0"
ruff = "==0.0.272"
ruff = "==0.4.5"
types-requests = "==2.27.7"
virtualenv = "20.21.0"

[tool.isort]
line_length = 88
Expand Down Expand Up @@ -111,11 +107,8 @@ wrap-summaries = 88
wrap-descriptions = 88

[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 88
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
ignore = ["B024", "B027"]
exclude = [
".bzr",
".direnv",
Expand All @@ -140,5 +133,10 @@ exclude = [
"proto",
]

[tool.ruff.lint]
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
ignore = ["B024", "B027"]

[tool.ruff.pydocstyle]
convention = "numpy"
8 changes: 4 additions & 4 deletions baselines/fedpara/fedpara/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def gen_client_fn(
args: Dict,
test_loader: Optional[List[DataLoader]] = None,
state_path: Optional[str] = None,
) -> Callable[[str], fl.client.NumPyClient]:
) -> Callable[[str], fl.client.Client]:
"""Return a function which creates a new FlowerClient for a given cid."""

def client_fn(cid: str) -> fl.client.NumPyClient:
def client_fn(cid: str) -> fl.client.Client:
"""Create a new FlowerClient for a given cid."""
cid_ = int(cid)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand All @@ -185,13 +185,13 @@ def client_fn(cid: str) -> fl.client.NumPyClient:
state_path=cl_path,
algorithm=args["algorithm"].lower(),
device=device,
)
).to_client()
return FlowerClient(
cid=cid_,
net=instantiate(model).to(device),
train_loader=train_loaders[cid_],
num_epochs=num_epochs,
device=device,
)
).to_client()

return client_fn
18 changes: 9 additions & 9 deletions baselines/fedpara/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
Expand All @@ -38,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = ">=3.10, <3.12.0" # don't change this
flwr = { extras = ["simulation"], version = "1.5.0" }
flwr = { extras = ["simulation"], version = "1.9.0" }
hydra-core = "1.3.2" # don't change this
matplotlib = "^3.7.2"
tqdm = "^4.66.1"
Expand All @@ -54,9 +52,9 @@ pylint = "==2.8.2"
flake8 = "==3.9.2"
pytest = "==6.2.4"
pytest-watch = "==4.2.0"
ruff = "==0.0.272"
ruff = "==0.4.5"
types-requests = "==2.27.7"
virtualenv = "20.21.0"
virtualenv = "20.26.6"

[tool.isort]
line_length = 88
Expand Down Expand Up @@ -109,11 +107,8 @@ wrap-summaries = 88
wrap-descriptions = 88

[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 88
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
ignore = ["B024", "B027"]
exclude = [
".bzr",
".direnv",
Expand All @@ -138,5 +133,10 @@ exclude = [
"proto",
]

[tool.ruff.lint]
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
ignore = ["B024", "B027"]

[tool.ruff.pydocstyle]
convention = "numpy"
12 changes: 6 additions & 6 deletions baselines/fedper/fedper/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import torch
from flwr.client import NumPyClient
from flwr.client import Client, NumPyClient
from flwr.common import NDArrays, Scalar
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Subset, random_split
Expand Down Expand Up @@ -239,7 +239,7 @@ def set_parameters(self, parameters: List[np.ndarray], evaluate=False) -> None:
def get_client_fn_simulation(
config: DictConfig,
client_state_save_path: str = "",
) -> Callable[[str], Union[FedPerClient, BaseClient]]:
) -> Callable[[str], Client]:
"""Generate the client function that creates the Flower Clients.
Parameters
Expand All @@ -251,7 +251,7 @@ def get_client_fn_simulation(
Returns
-------
Tuple[Callable[[str], FlowerClient], DataLoader]
Tuple[Callable[[str], Client], DataLoader]
A tuple containing the client function that creates Flower Clients and
the DataLoader that will be used for testing
"""
Expand Down Expand Up @@ -288,7 +288,7 @@ def get_client_fn_simulation(
)
# ------------------------------------------------------------

def client_fn(cid: str) -> BaseClient:
def client_fn(cid: str) -> Client:
"""Create a Flower client representing a single organization."""
cid_use = int(cid)
if config.dataset.name.lower() == "flickr":
Expand Down Expand Up @@ -343,12 +343,12 @@ def client_fn(cid: str) -> BaseClient:
client_essentials=client_essentials,
config=config,
model_manager_class=manager,
)
).to_client()
return BaseClient(
data_loaders=client_data_loaders,
client_essentials=client_essentials,
config=config,
model_manager_class=manager,
)
).to_client()

return client_fn
5 changes: 3 additions & 2 deletions baselines/fedper/fedper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import matplotlib.pyplot as plt
import numpy as np
from flwr.client import Client
from flwr.server.history import History
from omegaconf import DictConfig

from fedper.client import BaseClient, FedPerClient, get_client_fn_simulation
from fedper.client import get_client_fn_simulation
from fedper.implemented_models.mobile_model import MobileNet, MobileNetModelSplit
from fedper.implemented_models.resnet_model import ResNet, ResNetModelSplit

Expand Down Expand Up @@ -71,7 +72,7 @@ def set_client_state_save_path() -> str:

def get_client_fn(
config: DictConfig, client_state_save_path: str = ""
) -> Callable[[str], Union[FedPerClient, BaseClient]]:
) -> Callable[[str], Client]:
"""Get client function."""
# Get algorithm
algorithm = config.algorithm.lower()
Expand Down
Loading

0 comments on commit 98f84c8

Please sign in to comment.