Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ci(*:skip) Update client_fn args in e2e tests #3775

Merged
merged 6 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions e2e/bare-client-auth/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

model_params = np.array([1])
objective = 5


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand All @@ -23,10 +24,10 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)
11 changes: 6 additions & 5 deletions e2e/bare-https/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

model_params = np.array([1])
objective = 5


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand All @@ -25,17 +26,17 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
root_certificates=Path("certificates/ca.crt").read_bytes(),
Expand Down
14 changes: 6 additions & 8 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import numpy as np

import flwr as fl
from flwr.common import ConfigsRecord
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context

SUBSET_SIZE = 1000
STATE_VAR = "timestamp"
Expand All @@ -14,7 +14,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand Down Expand Up @@ -51,16 +51,14 @@ def evaluate(self, parameters, config):
)


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
3 changes: 2 additions & 1 deletion e2e/docker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -122,7 +123,7 @@ def evaluate(self, parameters, config):
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()

Expand Down
11 changes: 6 additions & 5 deletions e2e/framework-fastai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch
from fastai.vision.all import *

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

warnings.filterwarnings("ignore", category=UserWarning)

Expand All @@ -29,7 +30,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in learn.model.state_dict().items()]

Expand All @@ -49,18 +50,18 @@ def evaluate(self, parameters, config):
return loss, len(dls.valid), {"accuracy": 1 - error_rate}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
13 changes: 6 additions & 7 deletions e2e/framework-jax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import jax_training
import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Load data and determine model shape
train_x, train_y, test_x, test_y = jax_training.load_data()
grad_fn = jax.grad(jax_training.loss_fn)
model_shape = train_x.shape[1:]


class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self):
self.params = jax_training.load_model(model_shape)

Expand Down Expand Up @@ -48,16 +49,14 @@ def evaluate(
return float(loss), num_examples, {"loss": float(loss)}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
11 changes: 6 additions & 5 deletions e2e/framework-opacus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Define parameters.
PARAMS = {
Expand Down Expand Up @@ -95,7 +96,7 @@ def load_data():


# Define Flower client.
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self, model) -> None:
super().__init__()
# Create a privacy engine which will add DP and keep track of the privacy budget.
Expand Down Expand Up @@ -139,16 +140,16 @@ def evaluate(self, parameters, config):
return float(loss), len(testloader), {"accuracy": float(accuracy)}


def client_fn(cid):
def client_fn(context: Context):
model = Net()
return FlowerClient(model).to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080", client=FlowerClient(model).to_client()
)
11 changes: 6 additions & 5 deletions e2e/framework-pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
import pandas as pd

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

df = pd.read_csv("./data/client.csv")

Expand All @@ -16,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray:


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
Expand All @@ -32,17 +33,17 @@ def fit(
)


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
11 changes: 6 additions & 5 deletions e2e/framework-pytorch-lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pytorch_lightning as pl
import torch

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context


class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
self.train_loader = train_loader
Expand Down Expand Up @@ -51,15 +52,15 @@ def _set_parameters(model, parameters):
model.load_state_dict(state_dict, strict=True)


def client_fn(cid):
def client_fn(context: Context):
model = mnist.LitAutoEncoder()
train_loader, val_loader, test_loader = mnist.load_data()

# Flower client
return FlowerClient(model, train_loader, val_loader, test_loader).to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

Expand All @@ -71,7 +72,7 @@ def main() -> None:

# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)
start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions e2e/framework-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

import flwr as fl
from flwr.common import ConfigsRecord
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -89,7 +89,7 @@ def load_data():


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

Expand Down Expand Up @@ -136,18 +136,18 @@ def set_parameters(model, parameters):
return


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
Loading
Loading