Skip to content

Commit

Permalink
Merge branch 'main' into add-tests-for-mnist-m
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jul 18, 2024
2 parents 4cc1b15 + 02b1959 commit 0dbc988
Show file tree
Hide file tree
Showing 27 changed files with 155 additions and 85 deletions.
4 changes: 2 additions & 2 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ RUN apt-get install -y curl wget gnupg python3 python-is-python3 python3-pip git
build-essential tmux vim

RUN python -m pip install \
pip==24.0.0 \
setuptools==69.5.1 \
pip==24.1.2 \
setuptools==70.3.0 \
poetry==1.7.1

USER $USERNAME
Expand Down
4 changes: 2 additions & 2 deletions .github/actions/bootstrap/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ inputs:
default: 3.8
pip-version:
description: "Version of pip to be installed using pip"
default: 24.0.0
default: 24.1.2
setuptools-version:
description: "Version of setuptools to be installed using pip"
default: 69.5.1
default: 70.3.0
poetry-version:
description: "Version of poetry to be installed using pip"
default: 1.7.1
Expand Down
20 changes: 17 additions & 3 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import datasets
from datasets import Dataset, DatasetDict
Expand Down Expand Up @@ -65,6 +65,12 @@ class FederatedDataset:
Seed used for dataset shuffling. It has no effect if `shuffle` is False. The
seed cannot be set in the later stages. If `None`, then fresh, unpredictable
entropy will be pulled from the OS. Defaults to 42.
load_dataset_kwargs : Any
Additional keyword arguments passed to `datasets.load_dataset` function.
The parameters already used are dataset => path (in load_dataset) and
subset => name (in load_dataset). You can pass e.g., `num_proc=4`,
`trust_remote_code=True`. Do not pass any parameters that modify the
return type, i.e., that make `load_dataset` return something other than a DatasetDict.
Examples
--------
Expand All @@ -73,7 +79,7 @@ class FederatedDataset:
>>> from flwr_datasets import FederatedDataset
>>>
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
>>> # Load partition for client with ID 10.
>>> # Load partition for a client with ID 10.
>>> partition = fds.load_partition(10)
>>> # Use test split for centralized evaluation.
>>> centralized = fds.load_split("test")
Expand Down Expand Up @@ -107,6 +113,7 @@ def __init__(
partitioners: Dict[str, Union[Partitioner, int]],
shuffle: bool = True,
seed: Optional[int] = 42,
**load_dataset_kwargs: Any,
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
Expand All @@ -127,6 +134,7 @@ def __init__(
self._event = {
"load_partition": {split: False for split in self._partitioners},
}
self._load_dataset_kwargs = load_dataset_kwargs

def load_partition(
self,
Expand Down Expand Up @@ -289,8 +297,14 @@ def _prepare_dataset(self) -> None:
happen before the resplitting.
"""
self._dataset = datasets.load_dataset(
path=self._dataset_name, name=self._subset
path=self._dataset_name, name=self._subset, **self._load_dataset_kwargs
)
if not isinstance(self._dataset, datasets.DatasetDict):
raise ValueError(
"Probably one of the specified parameter in `load_dataset_kwargs` "
"change the return type of the datasets.load_dataset function. "
"Make sure to use parameter such that the return type is DatasetDict."
)
if self._shuffle:
# Note it shuffles all the splits. The self._dataset is DatasetDict
# so e.g. {"train": train_data, "test": test_data}. All splits get shuffled.
Expand Down
42 changes: 38 additions & 4 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def test_multiple_partitioners(self) -> None:
dataset_test_partition0 = dataset_fds.load_partition(0, self.test_split)

dataset = datasets.load_dataset(self.dataset_name)
self.assertEqual(
len(dataset_test_partition0),
len(dataset[self.test_split]) // num_test_partitions,
)
expected_len = len(dataset[self.test_split]) // num_test_partitions
mod = len(dataset[self.test_split]) % num_test_partitions
expected_len += 1 if 0 < mod else 0
self.assertEqual(len(dataset_test_partition0), expected_len)

def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:
"""Test if partitions got with and without split args are the same."""
Expand Down Expand Up @@ -217,6 +217,23 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
dataset_length = sum([len(ds) for ds in dataset.values()])
self.assertEqual(len(full), dataset_length)

def test_use_load_dataset_kwargs(self) -> None:
"""Test if the FederatedDataset works correctly with load_dataset_kwargs."""
try:
fds = FederatedDataset(
dataset=self.dataset_name,
shuffle=False,
partitioners={"train": 10},
# `num_proc` is not a named FederatedDataset parameter; it is an extra
# keyword argument meant to be forwarded to `datasets.load_dataset`.
num_proc=2,
)
# Loading a partition is what exercises the forwarded keyword arguments
# (presumably the dataset is only downloaded/prepared at this point).
_ = fds.load_partition(0)
# Catch as broadly as possible so that any error, whatever its type, is
# reported as a test failure with an explanatory message.
except Exception as e:  # pylint: disable=broad-except
self.fail(
f"Error when using load_dataset_kwargs: {e}. "
f"This code should not raise any exceptions."
)


class ShufflingResplittingOnArtificialDatasetTest(unittest.TestCase):
"""Test shuffling and resplitting using small artificial dataset.
Expand Down Expand Up @@ -417,6 +434,23 @@ def test_cannot_use_the_old_split_names(self) -> None:
with self.assertRaises(ValueError):
fds.load_partition(0, "train")

def test_use_load_dataset_kwargs(self) -> None:
"""Test if the FederatedDataset raises with incorrect load_dataset_kwargs.

The FederatedDataset should raise an error when the load_dataset_kwargs make
the return type of `datasets.load_dataset` different from a DatasetDict.
Passing `split` makes `load_dataset` return a single Dataset.
"""
fds = FederatedDataset(
dataset="mnist",
shuffle=False,
partitioners={"train": 10},
split="train",
)
# Construction is expected to succeed; the ValueError is expected only once
# a partition is requested.
with self.assertRaises(ValueError):
_ = fds.load_partition(0)


def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
"""Check if two Datasets have the same values."""
Expand Down
4 changes: 2 additions & 2 deletions dev/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ cd "$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"/../
./dev/rm-caches.sh

# Upgrade/install specific versions of `pip`, `setuptools`, and `poetry`
python -m pip install -U pip==24.0.0
python -m pip install -U setuptools==69.5.1
python -m pip install -U pip==24.1.2
python -m pip install -U setuptools==70.3.0
python -m pip install -U poetry==1.7.1

# Use `poetry` to install project dependencies
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def new(

framework_str = framework_str.lower()

llm_challenge_str = None
if framework_str == "flowertune":
llm_challenge_value = prompt_options(
"Please select LLM challenge by typing in the number",
Expand Down Expand Up @@ -171,7 +172,7 @@ def new(
}

# List of files to render
if framework_str == "flowertune":
if llm_challenge_str:
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
Expand Down Expand Up @@ -228,10 +229,10 @@ def new(
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
f"{import_name}/server_app.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
f"{import_name}/client_app.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/new_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_new_correct_name(tmp_path: str) -> None:
}
expected_files_module = {
"__init__.py",
"server.py",
"client.py",
"server_app.py",
"client_app.py",
"task.py",
}

Expand Down
10 changes: 7 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ class FlowerClient(NumPyClient):

def fit(self, parameters, config):
self.set_parameters(parameters)
train(self.net, self.trainloader, epochs=1)
train(
self.net,
self.trainloader,
epochs=int(self.context.run_config["local-epochs"]),
)
return self.get_parameters(config={}), len(self.trainloader), {}

def evaluate(self, parameters, config):
Expand All @@ -45,8 +49,8 @@ def client_fn(context: Context):
CHECKPOINT, num_labels=2
).to(DEVICE)

partition_id = int(context.node_config['partition-id'])
num_partitions = int(context.node_config['num-partitions])
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
trainloader, valloader = load_data(partition_id, num_partitions)

# Return Client instance
Expand Down
18 changes: 10 additions & 8 deletions src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@ from $import_name.task import (
# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, data):
num_layers = 2
hidden_dim = 32
num_layers = int(self.context.run_config["num-layers"])
hidden_dim = int(self.context.run_config["hidden-dim"])
num_classes = 10
batch_size = 256
num_epochs = 1
learning_rate = 1e-1
batch_size = int(self.context.run_config["batch-size"])
learning_rate = float(self.context.run_config["lr"])
num_epochs = int(self.context.run_config["local-epochs"])

self.train_images, self.train_labels, self.test_images, self.test_labels = data
self.model = MLP(num_layers, self.train_images.shape[-1], hidden_dim, num_classes)
self.optimizer = optim.SGD(learning_rate=learning_rate)
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
self.model = MLP(
num_layers, self.train_images.shape[-1], hidden_dim, num_classes
)
self.optimizer = optim.SGD(learning_rate=learning_rate)
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
self.num_epochs = num_epochs
self.batch_size = batch_size

Expand Down
8 changes: 7 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ class FlowerClient(NumPyClient):

def fit(self, parameters, config):
set_weights(self.net, parameters)
results = train(self.net, self.trainloader, self.valloader, 1, DEVICE)
results = train(
self.net,
self.trainloader,
self.valloader,
int(self.context.run_config["local-epochs"]),
DEVICE,
)
return get_weights(self.net), len(self.trainloader.dataset), results

def evaluate(self, parameters, config):
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ class FlowerClient(NumPyClient):

return loss, len(self.X_test), {"accuracy": accuracy}

fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})

def client_fn(context: Context):
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
dataset = fds.load_partition(partition_id, "train").with_format("numpy")

X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
Expand Down
11 changes: 9 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ class FlowerClient(NumPyClient):

def fit(self, parameters, config):
self.model.set_weights(parameters)
self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32, verbose=0)
self.model.fit(
self.x_train,
self.y_train,
epochs=int(self.context.run_config["local-epochs"]),
batch_size=int(self.context.run_config["batch-size"]),
verbose=bool(self.context.run_config.get("verbose")),
)
return self.model.get_weights(), len(self.x_train), {}

def evaluate(self, parameters, config):
Expand All @@ -34,7 +40,8 @@ def client_fn(context: Context):
net = load_model()

partition_id = int(context.node_config["partition-id"])
x_train, y_train, x_test, y_test = load_data(partition_id, 2)
num_partitions = int(context.node_config["num-partitions"])
x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)

# Return Client instance
return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ from flwr.client import ClientApp
from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig

from $import_name.client import gen_client_fn, get_parameters
from $import_name.client_app import gen_client_fn, get_parameters
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
from $import_name.models import get_model
from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config
from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config

# Avoid warnings
warnings.filterwarnings("ignore", category=UserWarning)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""$project_name: A Flower / FlowerTune app."""

from $import_name.client import set_parameters
from $import_name.client_app import set_parameters
from $import_name.models import get_model


Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ packages = ["."]
publisher = "$username"
[tool.flwr.app.components]
serverapp = "$import_name.server:app"
clientapp = "$import_name.client:app"
serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
[tool.flwr.federations]
default = "localhost"
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ packages = ["."]
publisher = "$username"
[tool.flwr.app.components]
serverapp = "$import_name.server:app"
clientapp = "$import_name.client:app"
serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
Expand Down
9 changes: 7 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ packages = ["."]
publisher = "$username"
[tool.flwr.app.components]
serverapp = "$import_name.server:app"
clientapp = "$import_name.client:app"
serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
num-layers = "2"
hidden-dim = "32"
batch-size = "256"
lr = "0.1"
[tool.flwr.federations]
default = "localhost"
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ packages = ["."]
publisher = "$username"

[tool.flwr.app.components]
serverapp = "$import_name.server:app"
clientapp = "$import_name.client:app"
serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = "3"
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ packages = ["."]
publisher = "$username"
[tool.flwr.app.components]
serverapp = "$import_name.server:app"
clientapp = "$import_name.client:app"
serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
[tool.flwr.federations]
default = "localhost"
Expand Down
Loading

0 comments on commit 0dbc988

Please sign in to comment.