Skip to content

Commit

Permalink
Merge branch 'main' into add-simulation-engine-plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jul 13, 2024
2 parents 1af5c9e + 5505e0a commit f04461a
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 78 deletions.
9 changes: 5 additions & 4 deletions e2e/bare-client-auth/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

model_params = np.array([1])
objective = 5


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand All @@ -23,10 +24,10 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)
11 changes: 6 additions & 5 deletions e2e/bare-https/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

model_params = np.array([1])
objective = 5


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand All @@ -25,17 +26,17 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
root_certificates=Path("certificates/ca.crt").read_bytes(),
Expand Down
14 changes: 6 additions & 8 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import numpy as np

import flwr as fl
from flwr.common import ConfigsRecord
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context

SUBSET_SIZE = 1000
STATE_VAR = "timestamp"
Expand All @@ -14,7 +14,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params

Expand Down Expand Up @@ -51,16 +51,14 @@ def evaluate(self, parameters, config):
)


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
3 changes: 2 additions & 1 deletion e2e/docker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -122,7 +123,7 @@ def evaluate(self, parameters, config):
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()

Expand Down
11 changes: 6 additions & 5 deletions e2e/framework-fastai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch
from fastai.vision.all import *

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

warnings.filterwarnings("ignore", category=UserWarning)

Expand All @@ -29,7 +30,7 @@


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in learn.model.state_dict().items()]

Expand All @@ -49,18 +50,18 @@ def evaluate(self, parameters, config):
return loss, len(dls.valid), {"accuracy": 1 - error_rate}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
13 changes: 6 additions & 7 deletions e2e/framework-jax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import jax_training
import numpy as np

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Load data and determine model shape
train_x, train_y, test_x, test_y = jax_training.load_data()
grad_fn = jax.grad(jax_training.loss_fn)
model_shape = train_x.shape[1:]


class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self):
self.params = jax_training.load_model(model_shape)

Expand Down Expand Up @@ -48,16 +49,14 @@ def evaluate(
return float(loss), num_examples, {"loss": float(loss)}


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
11 changes: 6 additions & 5 deletions e2e/framework-opacus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Define parameters.
PARAMS = {
Expand Down Expand Up @@ -95,7 +96,7 @@ def load_data():


# Define Flower client.
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self, model) -> None:
super().__init__()
# Create a privacy engine which will add DP and keep track of the privacy budget.
Expand Down Expand Up @@ -139,16 +140,16 @@ def evaluate(self, parameters, config):
return float(loss), len(testloader), {"accuracy": float(accuracy)}


def client_fn(cid):
def client_fn(context: Context):
model = Net()
return FlowerClient(model).to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080", client=FlowerClient(model).to_client()
)
11 changes: 6 additions & 5 deletions e2e/framework-pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
import pandas as pd

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

df = pd.read_csv("./data/client.csv")

Expand All @@ -16,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray:


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
Expand All @@ -32,17 +33,17 @@ def fit(
)


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
11 changes: 6 additions & 5 deletions e2e/framework-pytorch-lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pytorch_lightning as pl
import torch

import flwr as fl
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context


class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
self.train_loader = train_loader
Expand Down Expand Up @@ -51,15 +52,15 @@ def _set_parameters(model, parameters):
model.load_state_dict(state_dict, strict=True)


def client_fn(cid):
def client_fn(context: Context):
model = mnist.LitAutoEncoder()
train_loader, val_loader, test_loader = mnist.load_data()

# Flower client
return FlowerClient(model, train_loader, val_loader, test_loader).to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)

Expand All @@ -71,7 +72,7 @@ def main() -> None:

# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)
start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions e2e/framework-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

import flwr as fl
from flwr.common import ConfigsRecord
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -89,7 +89,7 @@ def load_data():


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

Expand Down Expand Up @@ -136,18 +136,18 @@ def set_parameters(model, parameters):
return


def client_fn(cid):
def client_fn(context: Context):
return FlowerClient().to_client()


app = fl.client.ClientApp(
app = ClientApp(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
Loading

0 comments on commit f04461a

Please sign in to comment.