Skip to content

Commit

Permalink
Migrate fed kaplan to flwr next (#2996)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
adam-narozniak and danieljanes authored Feb 23, 2024
1 parent 3c28c26 commit 52b1b09
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 71 deletions.
29 changes: 18 additions & 11 deletions examples/federated-kaplan-meier-fitter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ This will create a new directory called `federated-kaplan-meier-fitter` containi
-- client.py
-- server.py
-- centralized.py
-- run.sh
-- README.md
```

Expand Down Expand Up @@ -68,24 +67,32 @@ pip install -r requirements.txt

## Run Federated Survival Analysis with Flower and lifelines's KaplanMeierFitter

Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:
### Start the long-running Flower server (SuperLink)

```shell
$ python3 server.py
```bash
flower-superlink --insecure
```

Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands.
### Start the long-running Flower client (SuperNode)

Start client 1 in the first terminal:
In a new terminal window, start the first long-running Flower client:

```shell
$ python3 client.py --node-id 0
```bash
flower-client-app client:node_1_app --insecure
```

Start client 2 in the second terminal:
In yet another new terminal window, start the second long-running Flower client:

```shell
$ python3 client.py --node-id 1
```bash
flower-client-app client:node_2_app --insecure
```

### Run the Flower App

With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App:

```bash
flower-server-app server:app --insecure
```

You will see that the server is printing survival function, median survival time and saves the plot with the survival function.
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/federated-kaplan-meier-fitter/centralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
print(fitter.median_survival_time_)
fitter.plot_survival_function()
plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
plt.savefig("./survival_function_centralized.png", dpi=200)
plt.savefig("./_static/survival_function_centralized.png", dpi=200)
print("Centralized survival function saved.")
45 changes: 21 additions & 24 deletions examples/federated-kaplan-meier-fitter/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from typing import Dict, List, Tuple

import flwr as fl
Expand Down Expand Up @@ -41,28 +40,26 @@ def fit(
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
type=int,
required=True,
help="Node id. Each node holds different part of the dataset.",
)
args = parser.parse_args()
partition_id = args.node_id
# Prepare data
X = load_waltons()
partitioner = NaturalIdPartitioner(partition_by="group")
partitioner.dataset = Dataset.from_pandas(X)

# Prepare data
X = load_waltons()
partitioner = NaturalIdPartitioner(partition_by="group")
partitioner.dataset = Dataset.from_pandas(X)
partition = partitioner.load_partition(partition_id).to_pandas()
events = partition["E"].values
times = partition["T"].values

# Start Flower client
client = FlowerClient(times=times, events=events).to_client()
fl.client.start_client(
server_address="127.0.0.1:8080",
client=client,
)
def get_client_fn(partition_id: int):
def client_fn(cid: str):
partition = partitioner.load_partition(partition_id).to_pandas()
events = partition["E"].values
times = partition["T"].values
return FlowerClient(times=times, events=events).to_client()

return client_fn


# Run via `flower-client-app client:app`
node_1_app = fl.client.ClientApp(
client_fn=get_client_fn(0),
)
node_2_app = fl.client.ClientApp(
client_fn=get_client_fn(1),
)
2 changes: 1 addition & 1 deletion examples/federated-kaplan-meier-fitter/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ maintainers = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.9,<3.11"
flwr = ">=1.0,<2.0"
flwr-nightly = "*"
flwr-datasets = ">=0.0.2,<1.0.0"
numpy = ">=1.23.2"
pandas = ">=2.0.0"
Expand Down
2 changes: 1 addition & 1 deletion examples/federated-kaplan-meier-fitter/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
flwr>=1.0, <2.0
flwr-nightly
flwr-datasets>=0.0.2, <1.0.0
numpy>=1.23.2
pandas>=2.0.0
Expand Down
17 changes: 0 additions & 17 deletions examples/federated-kaplan-meier-fitter/run.sh

This file was deleted.

30 changes: 14 additions & 16 deletions examples/federated-kaplan-meier-fitter/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def aggregate_fit(
sorted_times = combined_times[args_sorted]
sorted_events = combined_events[args_sorted]
self.fitter.fit(sorted_times, sorted_events)
print("Survival function:")
print(self.fitter.survival_function_)
self.fitter.plot_survival_function()
plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
plt.savefig("./_static/survival_function_federated.png", dpi=200)
print("Mean survival time:")
print(self.fitter.median_survival_time_)
return None, {}

# The methods below return None or empty results.
Expand Down Expand Up @@ -129,19 +136,10 @@ def configure_evaluate(
return []


if __name__ == "__main__":
fitter = KaplanMeierFitter() # You can choose other method that work on E, T data
strategy = EventTimeFitterStrategy(min_num_clients=2, fitter=fitter)
# Start Flower server
fl.server.start_server(
server_address="127.0.0.1:8080",
config=fl.server.ServerConfig(num_rounds=1),
strategy=strategy,
)
print("Survival function:")
print(strategy.fitter.survival_function_)
strategy.fitter.plot_survival_function()
plt.title("Survival function of fruit flies (Walton's data)", fontsize=16)
plt.savefig("./survival_function_federated.png", dpi=200)
print("Mean survival time:")
print(strategy.fitter.median_survival_time_)
fitter = KaplanMeierFitter() # You can choose other method that work on E, T data
strategy = EventTimeFitterStrategy(min_num_clients=2, fitter=fitter)

app = fl.server.ServerApp(
config=fl.server.ServerConfig(num_rounds=1),
strategy=strategy,
)

0 comments on commit 52b1b09

Please sign in to comment.