Skip to content

Commit

Permalink
Update E2E tests to follow recent conventions
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes committed Jul 13, 2024
1 parent 09b7170 commit 129ad83
Show file tree
Hide file tree
Showing 14 changed files with 53 additions and 74 deletions.
8 changes: 3 additions & 5 deletions e2e/bare-client-auth/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Optional

import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

model_params = np.array([1])
objective = 5


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand All @@ -30,6 +28,6 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)
9 changes: 4 additions & 5 deletions e2e/bare-https/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from pathlib import Path
from typing import Optional

import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

model_params = np.array([1])
objective = 5


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand All @@ -31,13 +30,13 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
root_certificates=Path("certificates/ca.crt").read_bytes(),
Expand Down
11 changes: 4 additions & 7 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from datetime import datetime
from typing import Optional

import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context

SUBSET_SIZE = 1000
Expand All @@ -15,7 +14,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand Down Expand Up @@ -56,12 +55,10 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
1 change: 0 additions & 1 deletion e2e/docker/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
Expand Down
9 changes: 4 additions & 5 deletions e2e/framework-fastai/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import warnings
from collections import OrderedDict
from typing import Optional

import numpy as np
import torch
from fastai.vision.all import *

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

warnings.filterwarnings("ignore", category=UserWarning)
Expand All @@ -31,7 +30,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in learn.model.state_dict().items()]

Expand All @@ -55,14 +54,14 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
12 changes: 5 additions & 7 deletions e2e/framework-jax/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Flower client example using JAX for linear regression."""

from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

import jax
import jax_training
import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Load data and determine model shape
Expand All @@ -15,7 +15,7 @@
model_shape = train_x.shape[1:]


class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self):
self.params = jax_training.load_model(model_shape)

Expand Down Expand Up @@ -53,12 +53,10 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
9 changes: 4 additions & 5 deletions e2e/framework-opacus/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
Expand All @@ -10,7 +9,7 @@
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Define parameters.
Expand Down Expand Up @@ -97,7 +96,7 @@ def load_data():


# Define Flower client.
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self, model) -> None:
super().__init__()
# Create a privacy engine which will add DP and keep track of the privacy budget.
Expand Down Expand Up @@ -146,11 +145,11 @@ def client_fn(context: Context):
return FlowerClient(model).to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080", client=FlowerClient(model).to_client()
)
10 changes: 5 additions & 5 deletions e2e/framework-pandas/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

df = pd.read_csv("./data/client.csv")
Expand All @@ -17,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray:


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
Expand All @@ -37,13 +37,13 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
9 changes: 4 additions & 5 deletions e2e/framework-pytorch-lightning/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from collections import OrderedDict
from typing import Optional

import mnist
import pytorch_lightning as pl
import torch

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context


class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
self.train_loader = train_loader
Expand Down Expand Up @@ -61,7 +60,7 @@ def client_fn(context: Context):
return FlowerClient(model, train_loader, val_loader, test_loader).to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

Expand All @@ -73,7 +72,7 @@ def main() -> None:

# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)
start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
9 changes: 4 additions & 5 deletions e2e/framework-pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings
from collections import OrderedDict
from datetime import datetime
from typing import Optional

import torch
import torch.nn as nn
Expand All @@ -11,7 +10,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context

# #############################################################################
Expand Down Expand Up @@ -90,7 +89,7 @@ def load_data():


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

Expand Down Expand Up @@ -141,14 +140,14 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
10 changes: 4 additions & 6 deletions e2e/framework-scikit-learn/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Load MNIST dataset from https://www.openml.org/d/554
Expand All @@ -28,7 +28,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config): # type: ignore
return utils.get_model_parameters(model)

Expand All @@ -51,12 +51,10 @@ def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:8080", client=FlowerClient().to_client()
)
start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client())
Loading

0 comments on commit 129ad83

Please sign in to comment.