Add Workflow (#3017)
Co-authored-by: Daniel J. Beutel <[email protected]>
panh99 and danieljanes authored Feb 28, 2024
1 parent 256548e commit d663cd4
Showing 13 changed files with 680 additions and 177 deletions.
16 changes: 12 additions & 4 deletions examples/app-pytorch/README.md
@@ -12,10 +12,12 @@ Let's assume the following project structure:
```bash
$ tree .
.
-├── client.py          # <-- contains `ClientApp`
-├── server.py          # <-- contains `ServerApp`
-├── task.py            # <-- task-specific code (model, data)
-└── requirements.txt   # <-- dependencies
+├── client.py            # <-- contains `ClientApp`
+├── server.py            # <-- contains `ServerApp`
+├── server_workflow.py   # <-- contains `ServerApp` with workflow
+├── server_custom.py     # <-- contains `ServerApp` with custom main function
+├── task.py              # <-- task-specific code (model, data)
+└── requirements.txt     # <-- dependencies
```

## Install dependencies
@@ -52,6 +54,12 @@ With both the long-running server (SuperLink) and two clients (SuperNode) up and
```bash
flower-server-app server:app --insecure
```

Or, to try the workflow example, run:

```bash
flower-server-app server_workflow:app --insecure
```

Or, to try the custom server function example, run:

```bash
flower-server-app server_custom:app --insecure
```
55 changes: 55 additions & 0 deletions examples/app-pytorch/server_workflow.py
@@ -0,0 +1,55 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Context, Metrics
from flwr.server import Driver, LegacyContext


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    examples = [num_examples for num_examples, _ in metrics]

    # Weight each client's metrics by the number of examples used
    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
    train_accuracies = [
        num_examples * m["train_accuracy"] for num_examples, m in metrics
    ]
    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

    # Aggregate and return custom metric (weighted average)
    return {
        "train_loss": sum(train_losses) / sum(examples),
        "train_accuracy": sum(train_accuracies) / sum(examples),
        "val_loss": sum(val_losses) / sum(examples),
        "val_accuracy": sum(val_accuracies) / sum(examples),
    }


# Define strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Select all available clients
    fraction_evaluate=0.0,  # Disable evaluation
    min_available_clients=2,
    fit_metrics_aggregation_fn=weighted_average,
)


# Run via `flower-server-app server_workflow:app`
app = fl.server.ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
    # Construct the LegacyContext
    context = LegacyContext(
        state=context.state,
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )

    # Create the workflow
    workflow = fl.server.workflow.DefaultWorkflow()

    # Execute
    workflow(driver, context)
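
For context, the `ClientApp` referenced in `client.py` (not part of this diff) has to report the four metric keys that `weighted_average` aggregates. The snippet below is a hypothetical, minimal stand-in rather than the actual `client.py` of this example: the real client trains a PyTorch model via `task.py`, the names `FlowerClient` and `client_fn` are illustrative, and the `ClientApp(client_fn=...)` constructor and `flower-client-app` command are assumptions about this Flower release.

```python
import flwr as fl
from flwr.client import NumPyClient


class FlowerClient(NumPyClient):
    """Toy client reporting the metrics that `weighted_average` expects."""

    def get_parameters(self, config):
        return []  # placeholder: the real client returns model weights as NumPy arrays

    def fit(self, parameters, config):
        # Placeholder training step; the real client calls into task.py
        metrics = {
            "train_loss": 0.5,
            "train_accuracy": 0.80,
            "val_loss": 0.6,
            "val_accuracy": 0.75,
        }
        num_examples = 100
        return [], num_examples, metrics

    def evaluate(self, parameters, config):
        # The strategy above sets fraction_evaluate=0.0, so this is never called
        return 0.6, 100, {"val_accuracy": 0.75}


def client_fn(cid: str):
    return FlowerClient().to_client()


# Run via `flower-client-app client:app --insecure`
app = fl.client.ClientApp(client_fn=client_fn)
```

Because the strategy sets `fraction_evaluate=0.0`, only `fit` metrics ever reach `weighted_average`.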
4 changes: 4 additions & 0 deletions src/py/flwr/server/__init__.py
@@ -16,12 +16,14 @@


from . import strategy
from . import workflow as workflow
from .app import run_driver_api as run_driver_api
from .app import run_fleet_api as run_fleet_api
from .app import run_superlink as run_superlink
from .app import start_server as start_server
from .client_manager import ClientManager as ClientManager
from .client_manager import SimpleClientManager as SimpleClientManager
from .compat import LegacyContext as LegacyContext
from .compat import start_driver as start_driver
from .driver import Driver as Driver
from .history import History as History
@@ -34,6 +36,7 @@
"ClientManager",
"Driver",
"History",
"LegacyContext",
"run_driver_api",
"run_fleet_api",
"run_server_app",
@@ -45,4 +48,5 @@
"start_driver",
"start_server",
"strategy",
"workflow",
]
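
The two new re-exports above make the pieces used by `examples/app-pytorch/server_workflow.py` importable straight from `flwr.server`. A tiny, purely illustrative sanity check (only the import names come from this commit):

```python
import flwr as fl
from flwr.server import LegacyContext, workflow

# Both spellings resolve to the same class after this commit
assert workflow.DefaultWorkflow is fl.server.workflow.DefaultWorkflow

# LegacyContext is re-exported from the compat package (see the next diff below)
print(LegacyContext.__module__)  # flwr.server.compat.legacy_context
```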
2 changes: 2 additions & 0 deletions src/py/flwr/server/compat/__init__.py
@@ -16,7 +16,9 @@


from .app import start_driver as start_driver
from .legacy_context import LegacyContext as LegacyContext

__all__ = [
"LegacyContext",
"start_driver",
]
95 changes: 7 additions & 88 deletions src/py/flwr/server/compat/app.py
@@ -16,25 +16,21 @@


import sys
import threading
import time
from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Optional, Union

from flwr.common import EventType, event
from flwr.common.address import parse_address
from flwr.common.logger import log, warn_deprecated_feature
from flwr.proto import driver_pb2 # pylint: disable=E0611
from flwr.server.client_manager import ClientManager
from flwr.server.history import History
from flwr.server.server import Server, init_defaults, run_fl
from flwr.server.server_config import ServerConfig
from flwr.server.strategy import Strategy

from ..driver import Driver
from ..driver.grpc_driver import GrpcDriver
from .driver_client_proxy import DriverClientProxy
from .app_utils import start_update_client_manager_thread

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

@@ -104,11 +100,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
"""
event(EventType.START_DRIVER_ENTER)

if driver:
# pylint: disable=protected-access
grpc_driver, _ = driver._get_grpc_driver_and_run_id()
# pylint: enable=protected-access
else:
if driver is None:
# Not passing a `Driver` object is deprecated
warn_deprecated_feature("start_driver")

@@ -122,12 +114,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
# Create the Driver
if isinstance(root_certificates, str):
root_certificates = Path(root_certificates).read_bytes()
grpc_driver = GrpcDriver(
driver = Driver(
driver_service_address=address, root_certificates=root_certificates
)
grpc_driver.connect()

lock = threading.Lock()

# Initialize the Driver API server and config
initialized_server, initialized_config = init_defaults(
@@ -142,18 +131,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
initialized_config,
)

f_stop = threading.Event()
# Start the thread updating nodes
thread = threading.Thread(
target=update_client_manager,
args=(
grpc_driver,
initialized_server.client_manager(),
lock,
f_stop,
),
thread, f_stop = start_update_client_manager_thread(
driver, initialized_server.client_manager()
)
thread.start()

# Start training
hist = run_fl(
@@ -164,72 +145,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
f_stop.set()

# Stop the Driver API server and the thread
with lock:
if driver:
del driver
else:
grpc_driver.disconnect()
del driver

thread.join()

event(EventType.START_SERVER_LEAVE)

return hist


def update_client_manager(
driver: GrpcDriver,
client_manager: ClientManager,
lock: threading.Lock,
f_stop: threading.Event,
) -> None:
"""Update the nodes list in the client manager.
This function periodically communicates with the associated driver to get all
node_ids. Each node_id is then converted into a `DriverClientProxy` instance
and stored in the `registered_nodes` dictionary with node_id as key.
New nodes will be added to the ClientManager via `client_manager.register()`,
and dead nodes will be removed from the ClientManager via
`client_manager.unregister()`.
"""
# Request for run_id
run_id = driver.create_run(
driver_pb2.CreateRunRequest() # pylint: disable=E1101
).run_id

# Loop until the driver is disconnected
registered_nodes: Dict[int, DriverClientProxy] = {}
while not f_stop.is_set():
with lock:
# End the while loop if the driver is disconnected
if driver.stub is None:
break
get_nodes_res = driver.get_nodes(
req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
)
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
dead_nodes = set(registered_nodes).difference(all_node_ids)
new_nodes = all_node_ids.difference(registered_nodes)

# Unregister dead nodes
for node_id in dead_nodes:
client_proxy = registered_nodes[node_id]
client_manager.unregister(client_proxy)
del registered_nodes[node_id]

# Register new nodes
for node_id in new_nodes:
client_proxy = DriverClientProxy(
node_id=node_id,
driver=driver,
anonymous=False,
run_id=run_id,
)
if client_manager.register(client_proxy):
registered_nodes[node_id] = client_proxy
else:
raise RuntimeError("Could not register node.")

# Sleep for 3 seconds
time.sleep(3)
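
The `start_update_client_manager_thread` helper imported above lives in the new `app_utils.py`, which is among the changed files but not shown here. Below is a hedged reconstruction of what such a helper plausibly looks like, pieced together from the deleted `update_client_manager` function and the call site in `start_driver`; in particular, `driver.get_node_ids()`, `driver.run_id`, and the exact `DriverClientProxy` arguments are assumptions, not the actual `app_utils.py`.

```python
import threading
import time
from typing import Dict, Tuple

from flwr.server.client_manager import ClientManager

from ..driver import Driver
from .driver_client_proxy import DriverClientProxy


def start_update_client_manager_thread(
    driver: Driver,
    client_manager: ClientManager,
) -> Tuple[threading.Thread, threading.Event]:
    """Spawn a thread that keeps the ClientManager in sync with the Driver."""
    f_stop = threading.Event()
    thread = threading.Thread(
        target=_update_client_manager,
        args=(driver, client_manager, f_stop),
    )
    thread.start()
    return thread, f_stop


def _update_client_manager(
    driver: Driver,
    client_manager: ClientManager,
    f_stop: threading.Event,
) -> None:
    """Register new nodes and unregister dead ones until `f_stop` is set.

    Mirrors the deleted `update_client_manager` above; `driver.get_node_ids()`
    and `driver.run_id` are assumed APIs of the new Driver abstraction.
    """
    registered_nodes: Dict[int, DriverClientProxy] = {}
    while not f_stop.is_set():
        all_node_ids = set(driver.get_node_ids())  # assumption: Driver lists node IDs
        dead_nodes = set(registered_nodes).difference(all_node_ids)
        new_nodes = all_node_ids.difference(registered_nodes)

        # Unregister dead nodes
        for node_id in dead_nodes:
            client_manager.unregister(registered_nodes.pop(node_id))

        # Register new nodes
        for node_id in new_nodes:
            client_proxy = DriverClientProxy(
                node_id=node_id,
                driver=driver,  # assumption: the proxy accepts the new Driver directly
                anonymous=False,
                run_id=driver.run_id,  # assumption: run id is tracked by the Driver
            )
            if not client_manager.register(client_proxy):
                raise RuntimeError("Could not register node.")
            registered_nodes[node_id] = client_proxy

        # Sleep for 3 seconds, as the previous implementation did
        time.sleep(3)
```

Packaging the thread setup this way is what lets `start_driver` drop its explicit lock and `threading`/`time` imports, as the hunks above show.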
84 changes: 0 additions & 84 deletions src/py/flwr/server/compat/app_test.py

This file was deleted.
