refactor(baselines) Update torch and torchvision versions in old baselines (#4783)
chongshenng authored Jan 7, 2025
1 parent dd19f5f commit 9508d16
Showing 9 changed files with 26 additions and 20 deletions.
baselines/flwr_baselines/dev/test.sh (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@ echo "=== test.sh ==="
python -m isort --check-only . && echo "- isort: done" &&
python -m black --check . && echo "- black: done" &&
python -m docformatter -i -r flwr_baselines && echo "- docformatter: done" &&
-python -m mypy flwr_baselines && echo "- mypy: done" &&
+python -m mypy --explicit-package-bases flwr_baselines && echo "- mypy: done" &&
python -m pylint flwr_baselines && echo "- pylint: done" &&
python -m pytest --durations=0 -v flwr_baselines && echo "- pytest: done" &&
echo "- All Python checks passed"
@@ -375,7 +375,7 @@ def test_sample_without_replacement(self) -> None:
# Prepare
distribution = np.array([0.0, 1.0, 0.0], dtype=np.float32)
empty_classes = [False, False, True]
-list_samples = [
+list_samples: List[List[np.ndarray]] = [
[
np.zeros((3, 1, 1), dtype=np.float32),
np.zeros((3, 1, 1), dtype=np.float32),
@@ -245,7 +245,7 @@ def shuffle_and_create_cifar100_lda_dists(
lda_concentration_coarse: float,
lda_concentration_fine: float,
num_partitions: int,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]], np.ndarray, np.ndarray]:
"""Shuffles the original dataset and creates the two-level LDA
distributions.
@@ -329,10 +329,10 @@ def partition_cifar100_and_save(

# obtain sample
sample_x: np.ndarray = x_list[real_class][0]
-x_list[real_class] = np.delete(x_list[real_class], 0, 0)
+x_list[real_class] = np.delete(x_list[real_class], 0, 0)  # type: ignore

sample_y: np.ndarray = y_list[real_class][0]
-y_list[real_class] = np.delete(y_list[real_class], 0, 0)
+y_list[real_class] = np.delete(y_list[real_class], 0, 0)  # type: ignore

x_this_client.append(sample_x)
y_this_client.append(sample_y)
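A plausible reading of the new ignores: with the lists now annotated as List[List[np.ndarray]], mypy 1.x rejects assigning the single np.ndarray that np.delete() returns back into a list-typed slot, and the comments suppress that error. A minimal, hypothetical reproduction (not from the repository):

from typing import List

import numpy as np

x_list: List[List[np.ndarray]] = [[np.zeros((3, 1, 1), dtype=np.float32)]]
# mypy flags the assignment (ndarray is not List[ndarray]); hence "# type: ignore"
x_list[0] = np.delete(x_list[0], 0, 0)  # type: ignore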
@@ -7,6 +7,8 @@

import flwr as fl
import torch
+from flwr.client.client import Client
+from flwr.common import Context
from flwr.common.typing import NDArrays, Scalar
from torch.utils.data import DataLoader

@@ -74,7 +76,7 @@ def gen_client_fn(
num_epochs: int,
batch_size: int,
learning_rate: float,
-) -> Tuple[Callable[[str], FlowerClient], DataLoader]:
+) -> Tuple[Callable[[Context], Client], DataLoader]:
"""Generates the client function that creates the Flower Clients.
Parameters
@@ -109,20 +111,22 @@ def gen_client_fn(
iid=iid, balance=balance, num_clients=num_clients, batch_size=batch_size
)

-def client_fn(cid: str) -> FlowerClient:
+def client_fn(context: Context) -> Client:
"""Create a Flower client representing a single organization."""

# Load model
net = model.Net().to(device)

+partition_id = context.node_config["partition-id"]

# Note: each client gets a different trainloader/valloader, so each client
# will train and evaluate on their own unique data
-trainloader = trainloaders[int(cid)]
-valloader = valloaders[int(cid)]
+trainloader = trainloaders[int(partition_id)]
+valloader = valloaders[int(partition_id)]

# Create a single Flower client representing a single organization
return FlowerClient(
net, trainloader, valloader, device, num_epochs, learning_rate
-)
+).to_client()

return client_fn, testloader
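For orientation (not part of this diff): a client_fn with this Context-based signature is what Flower's newer ClientApp entry point consumes. A minimal sketch, assuming Flower >= 1.8 and using a placeholder client in place of the baseline's FlowerClient:

from flwr.client import ClientApp, NumPyClient
from flwr.client.client import Client
from flwr.common import Context


class _PlaceholderClient(NumPyClient):
    # Stand-in for the baseline's FlowerClient; returns empty parameters/metrics.
    def get_parameters(self, config):
        return []

    def fit(self, parameters, config):
        return [], 1, {}

    def evaluate(self, parameters, config):
        return 0.0, 1, {}


def client_fn(context: Context) -> Client:
    # The partition id in node_config selects this client's data shard, as above.
    partition_id = int(context.node_config["partition-id"])
    print(f"Constructing client for partition {partition_id}")
    return _PlaceholderClient().to_client()


app = ClientApp(client_fn=client_fn)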
@@ -86,7 +86,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
The weighted average metric.
"""
# Multiply accuracy of each client by number of examples used
-accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+accuracies = [
+    float(num_examples) * float(m["accuracy"]) for num_examples, m in metrics
+]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
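As a usage note (a sketch rather than part of this commit): a weighted-average function like this is typically handed to the strategy as its metrics-aggregation hook, e.g. with FedAvg:

from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Weight each client's accuracy by its number of examples (mirrors the code above)
    accuracies = [float(n) * float(m["accuracy"]) for n, m in metrics]
    examples = [n for n, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# Hypothetical wiring; the baseline's own server setup may differ
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)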
@@ -1,6 +1,6 @@
"""Client implementation for federated learning."""

-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

import flwr as fl
import torch
@@ -37,7 +37,7 @@ def __init__(
device: torch.device,
num_epochs: int,
learning_rate: float,
-num_batches: int = None,
+num_batches: Optional[int] = None,
) -> None:
"""
@@ -111,7 +111,7 @@ def create_client(
num_epochs: int,
learning_rate: float,
num_classes: int = 62,
-num_batches: int = None,
+num_batches: Optional[int] = None,
) -> FlowerClient:
"""Create client for the flower simulation."""
net = Net(num_classes).to(device)
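The Optional[...] changes in this file (and in the files below) track a mypy behaviour change: from mypy 0.990 onward, implicit Optional for None defaults is disabled by default, so a None default must be annotated explicitly. Illustrative only; the function names here are made up:

from typing import Optional


def make_batches(num_batches: int = None) -> None:  # rejected by mypy >= 0.990
    ...


def make_batches_ok(num_batches: Optional[int] = None) -> None:  # accepted
    ...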
@@ -2,7 +2,7 @@

import pathlib
from logging import INFO
-from typing import List, Tuple
+from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
@@ -118,7 +118,7 @@ def train_valid_test_partition(
train_split: float = 0.9,
validation_split: float = 0.0,
test_split: float = 0.1,
-random_seed: int = None,
+random_seed: Optional[int] = None,
) -> Tuple[List[Dataset], List[Dataset], List[Dataset]]:
"""Partition list of datasets to train, validation and test splits (each
dataset from the list individually).
@@ -32,7 +32,7 @@ def sample(
sampling_type: str,
frac: float,
n_clients: Optional[int] = None,
-random_seed: int = None,
+random_seed: Optional[int] = None,
) -> pd.DataFrame:
"""Samples data reference stored in the self._data_info_df.
baselines/flwr_baselines/pyproject.toml (6 changes: 3 additions & 3 deletions)
@@ -39,8 +39,8 @@ classifiers = [
python = ">=3.8.15, <=3.11.0"
# Mandatory dependencies
flwr = { extras = ["simulation"], version = "^1.3.0" }
torch = "^1.10.1"
torchvision = "^0.11.2"
torch = "==2.4.1"
torchvision = "==0.19.1"
hydra-core = "^1.2.0"
numpy = "^1.20.0"
tqdm = "4.66.3"
@@ -57,7 +57,7 @@ pillow = "==10.2.0"
isort = "==5.13.2"
black = "==24.2.0"
docformatter = "==1.7.5"
mypy = "==0.961"
mypy = "==1.8.0"
pylint = "==2.8.2"
flake8 = "==3.9.2"
pytest = "==6.2.4"
