From 40f257ee57c4a5c35e2ca9aace67026ba4a54e57 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Fri, 5 Apr 2024 15:56:36 +0200 Subject: [PATCH 1/2] Initial draft --- examples/quickstart-pytorch/README.md | 8 ++- examples/quickstart-pytorch/client.py | 70 ++++++++++++++++----------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 978191cc0ecd..03ac96190622 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -90,7 +90,13 @@ flower-superlink --insecure Start 2 Flower `SuperNodes` in 2 separate terminal windows, using: ```bash -flower-client-app client:app --insecure +flower-client-app client:partition_0 --insecure +``` + +And: + +```bash +flower-client-app client:partition_1 --insecure ``` ### 3. Run the Flower App diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index be4be88b8f8d..c25fb8a15216 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -94,59 +94,73 @@ def apply_transforms(batch): # 2. Federation of the pipeline with Flower # ############################################################################# -# Get partition id -parser = argparse.ArgumentParser(description="Flower") -parser.add_argument( - "--partition-id", - choices=[0, 1, 2], - default=0, - type=int, - help="Partition of the dataset divided into 3 iid partitions created artificially.", -) -partition_id = parser.parse_known_args()[0].partition_id - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) -trainloader, testloader = load_data(partition_id=partition_id) - # Define Flower client class FlowerClient(NumPyClient): + def __init__(self, net, data): + super().__init__() + self.net = net + self.trainloader, self.testloader = data + def get_parameters(self, config): - return [val.cpu().numpy() for _, val in net.state_dict().items()] + return [val.cpu().numpy() for _, val in self.net.state_dict().items()] def set_parameters(self, parameters): - params_dict = zip(net.state_dict().keys(), parameters) + params_dict = zip(self.net.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - net.load_state_dict(state_dict, strict=True) + self.net.load_state_dict(state_dict, strict=True) def fit(self, parameters, config): self.set_parameters(parameters) - train(net, trainloader, epochs=1) - return self.get_parameters(config={}), len(trainloader.dataset), {} + train(self.net, self.trainloader, epochs=1) + return self.get_parameters(config={}), len(self.trainloader.dataset), {} def evaluate(self, parameters, config): self.set_parameters(parameters) - loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} + loss, accuracy = test(self.net, self.testloader) + return loss, len(self.testloader.dataset), {"accuracy": accuracy} + +def get_client_fn(partition_id): + net = Net().to(DEVICE) + data = load_data(partition_id=partition_id) -def client_fn(cid: str): - """Create and return an instance of Flower `Client`.""" - return FlowerClient().to_client() + def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient(net, data).to_client() + + return client_fn # Flower ClientApp -app = ClientApp( - client_fn=client_fn, +partition_0 = ClientApp( + client_fn=get_client_fn(0), ) +partition_1 = ClientApp( + client_fn=get_client_fn(0), +) # Legacy mode if __name__ == "__main__": from flwr.client import start_client + # Get partition id + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + choices=[0, 1, 2], + default=0, + type=int, + help="Partition of the dataset divided into 3 iid partitions created artificially.", + ) + partition_id = parser.parse_known_args()[0].partition_id + + # Load model and data (simple CNN, CIFAR-10) + net = Net().to(DEVICE) + data = load_data(partition_id=partition_id) + start_client( server_address="127.0.0.1:8080", - client=FlowerClient().to_client(), + client=FlowerClient(net, data).to_client(), ) From 745739c16a82e85a60fd097344ec95f411a59617 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Fri, 5 Apr 2024 16:00:28 +0200 Subject: [PATCH 2/2] Use correct partition --- examples/quickstart-pytorch/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index c25fb8a15216..9cf7da469b41 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -138,7 +138,7 @@ def client_fn(cid: str): ) partition_1 = ClientApp( - client_fn=get_client_fn(0), + client_fn=get_client_fn(1), ) # Legacy mode