From 2ab4ed6e8648b9d668acdc87b3ef66080e31e884 Mon Sep 17 00:00:00 2001 From: lemonviv Date: Sat, 23 Mar 2024 09:56:40 +0800 Subject: [PATCH] Update the HFL example and README --- examples/hfl/README.md | 10 +++++----- examples/hfl/src/client.py | 2 ++ examples/hfl/src/server.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/hfl/README.md b/examples/hfl/README.md index cf20e64cd..2916bc560 100644 --- a/examples/hfl/README.md +++ b/examples/hfl/README.md @@ -27,7 +27,7 @@ This example uses the Bank dataset and an MLP model in FL. ## Preparation -Go to the Conda environment that contains the Singa library, and run +Go to the Conda environment that contains the Singa library, and install the required libraries. ```bash pip install -r requirements.txt @@ -41,18 +41,18 @@ Download the bank dataset and split it into 3 partitions. # 3. run the following command which: # (1) splits the dataset into N subsets # (2) splits each subsets into train set and test set (8:2) -python -m bank N +python -m bank 3 ``` ## Run the example -Run the server first (set the number of epochs to 3) +Run the server first (set the maximum number of epochs to 3 by the "-m" parameter) ```bash python -m src.server -m 3 --num_clients 3 ``` -Then, start 3 clients in different terminal +Then, start 3 clients in different terminals (similarly set the maximum number of epochs to 3) ```bash python -m src.client --model mlp --data bank -m 3 -i 0 -d non-iid @@ -60,4 +60,4 @@ python -m src.client --model mlp --data bank -m 3 -i 1 -d non-iid python -m src.client --model mlp --data bank -m 3 -i 2 -d non-iid ``` -Finally, the server and clients finish the FL training. \ No newline at end of file +Finally, the server and clients finish the FL training. diff --git a/examples/hfl/src/client.py b/examples/hfl/src/client.py index 80ab11f3a..dbff42b4d 100644 --- a/examples/hfl/src/client.py +++ b/examples/hfl/src/client.py @@ -40,6 +40,7 @@ np_dtype = {"float16": np.float16, "float32": np.float32} singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + class Client: """Client sends and receives protobuf messages. @@ -63,6 +64,7 @@ def __init__( Args: global_rank (int, optional): The rank in training process. Defaults to 0. + Provided by the '-i' parameter (device_id) in the running script. host (str, optional): Host ip address. Defaults to '127.0.0.1'. port (str, optional): Port. Defaults to 1234. """ diff --git a/examples/hfl/src/server.py b/examples/hfl/src/server.py index 7450cc1cf..68780e13c 100644 --- a/examples/hfl/src/server.py +++ b/examples/hfl/src/server.py @@ -80,6 +80,7 @@ def __start_rank_pairing(self) -> None: """Start pair each client to a global rank""" for _ in range(self.num_clients): conn, addr = self.sock.accept() + # rank is the global device_id when initializing the client rank = utils.receive_int(conn) self.conns[rank] = conn self.addrs[rank] = addr