Skip to content

Commit

Permalink
feat(framework) Use types in template configs (#3875)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Jul 23, 2024
1 parent 7ad7bd4 commit f3602b6
Show file tree
Hide file tree
Showing 19 changed files with 81 additions and 44 deletions.
8 changes: 5 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ from $import_name.task import (

# Flower client
class FlowerClient(NumPyClient):
def __init__(self, net, trainloader, testloader):
def __init__(self, net, trainloader, testloader, local_epochs):
self.net = net
self.trainloader = trainloader
self.testloader = testloader
self.local_epochs = local_epochs

def get_parameters(self, config):
return get_weights(self.net)
Expand All @@ -33,7 +34,7 @@ class FlowerClient(NumPyClient):
train(
self.net,
self.trainloader,
epochs=int(self.context.run_config["local-epochs"]),
epochs=self.local_epochs,
)
return self.get_parameters(config={}), len(self.trainloader), {}

Expand All @@ -52,9 +53,10 @@ def client_fn(context: Context):
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
trainloader, valloader = load_data(partition_id, num_partitions)
local_epochs = context.run_config["local-epochs"]

# Return Client instance
return FlowerClient(net, trainloader, valloader).to_client()
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


# Flower ClientApp
Expand Down
34 changes: 26 additions & 8 deletions src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,22 @@ from $import_name.task import (

# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, data):
num_layers = int(self.context.run_config["num-layers"])
hidden_dim = int(self.context.run_config["hidden-dim"])
num_classes = 10
batch_size = int(self.context.run_config["batch-size"])
learning_rate = float(self.context.run_config["lr"])
num_epochs = int(self.context.run_config["local-epochs"])
def __init__(
self,
data,
num_layers,
hidden_dim,
num_classes,
batch_size,
learning_rate,
num_epochs,
):
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.num_classes = num_classes
self.batch_size = batch_size
self.learning_rate = learning_rate
self.num_epochs = num_epochs

self.train_images, self.train_labels, self.test_images, self.test_labels = data
self.model = MLP(
Expand Down Expand Up @@ -65,8 +74,17 @@ def client_fn(context: Context):
num_partitions = int(context.node_config["num-partitions"])
data = load_data(partition_id, num_partitions)

num_layers = context.run_config["num-layers"]
hidden_dim = context.run_config["hidden-dim"]
num_classes = 10
batch_size = context.run_config["batch-size"]
learning_rate = context.run_config["lr"]
num_epochs = context.run_config["local-epochs"]

# Return Client instance
return FlowerClient(data).to_client()
return FlowerClient(
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
).to_client()


# Flower ClientApp
Expand Down
8 changes: 5 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ from $import_name.task import (

# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, net, trainloader, valloader):
def __init__(self, net, trainloader, valloader, local_epochs):
self.net = net
self.trainloader = trainloader
self.valloader = valloader
self.local_epochs = local_epochs

def fit(self, parameters, config):
set_weights(self.net, parameters)
results = train(
self.net,
self.trainloader,
self.valloader,
int(self.context.run_config["local-epochs"]),
self.local_epochs,
DEVICE,
)
return get_weights(self.net), len(self.trainloader.dataset), results
Expand All @@ -44,9 +45,10 @@ def client_fn(context: Context):
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
trainloader, valloader = load_data(partition_id, num_partitions)
local_epochs = context.run_config["local-epochs"]

# Return Client instance
return FlowerClient(net, trainloader, valloader).to_client()
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


# Flower ClientApp
Expand Down
20 changes: 15 additions & 5 deletions src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ from $import_name.task import load_data, load_model

# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, model, x_train, y_train, x_test, y_test):
def __init__(
self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
):
self.model = model
self.x_train = x_train
self.y_train = y_train
self.x_test = x_test
self.y_test = y_test
self.epochs = epochs
self.batch_size = batch_size
self.verbose = verbose

def get_parameters(self, config):
return self.model.get_weights()
Expand All @@ -23,9 +28,9 @@ class FlowerClient(NumPyClient):
self.model.fit(
self.x_train,
self.y_train,
epochs=int(self.context.run_config["local-epochs"]),
batch_size=int(self.context.run_config["batch-size"]),
verbose=bool(self.context.run_config.get("verbose")),
epochs=self.epochs,
batch_size=self.batch_size,
verbose=self.verbose,
)
return self.model.get_weights(), len(self.x_train), {}

Expand All @@ -42,9 +47,14 @@ def client_fn(context: Context):
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
epochs = context.run_config["local-epochs"]
batch_size = context.run_config["batch-size"]
verbose = context.run_config.get("verbose")

# Return Client instance
return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
return FlowerClient(
net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
).to_client()


# Flower ClientApp
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.hf.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = FedAvg(
Expand All @@ -18,5 +18,6 @@ def server_fn(context: Context):

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = FedAvg()
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ from flwr.server.strategy import FedAvg

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = FedAvg()
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ from flwr.server.strategy import FedAvg

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = FedAvg()
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ parameters = ndarrays_to_parameters(ndarrays)

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = FedAvg(
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = FedAvg(
Expand All @@ -19,5 +19,6 @@ def server_fn(context: Context):

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ parameters = ndarrays_to_parameters(load_model().get_weights())

def server_fn(context: Context):
# Read from config
num_rounds = int(context.run_config["num-server-rounds"])
num_rounds = context.run_config["num-server-rounds"]

# Define strategy
strategy = strategy = FedAvg(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ serverapp = "$import_name.app:server"
clientapp = "$import_name.app:client"
[tool.flwr.app.config]
num-server-rounds = "3"
num-server-rounds = 3
[tool.flwr.federations]
default = "localhost"
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
num-server-rounds = 3
local-epochs = 1
[tool.flwr.federations]
default = "localhost"
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
num-server-rounds = 3
[tool.flwr.federations]
default = "localhost"
Expand Down
12 changes: 6 additions & 6 deletions src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
num-layers = "2"
hidden-dim = "32"
batch-size = "256"
lr = "0.1"
num-server-rounds = 3
local-epochs = 1
num-layers = 2
hidden-dim = 32
batch-size = 256
lr = 0.1
[tool.flwr.federations]
default = "localhost"
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = "3"
num-server-rounds = 3

[tool.flwr.federations]
default = "localhost"
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"
[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
num-server-rounds = 3
local-epochs = 1
[tool.flwr.federations]
default = "localhost"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = "3"
num-server-rounds = 3

[tool.flwr.federations]
default = "localhost"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ serverapp = "$import_name.server_app:app"
clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = "3"
local-epochs = "1"
batch-size = "32"
verbose = "" # Empty string means False
num-server-rounds = 3
local-epochs = 1
batch-size = 32
verbose = false

[tool.flwr.federations]
default = "localhost"
Expand Down

0 comments on commit f3602b6

Please sign in to comment.