diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb index ce4c2bb63606..bbd916b32375 100644 --- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb +++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb @@ -484,7 +484,7 @@ " min_available_clients=10, # Wait until all 10 clients are available\n", ")\n", "\n", - "# Specify the resources each of your clients need. By default, each \n", + "# Specify the resources each of your clients need. By default, each\n", "# client will be allocated 1x CPU and 0x CPUs\n", "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "if DEVICE.type == \"cuda\":\n", diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index a374909c33b2..e8f3ec3fd724 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -73,10 +73,10 @@ def apply_transforms(batch): def train( - net: Net, - trainloader: torch.utils.data.DataLoader, - epochs: int, - device: torch.device, # pylint: disable=no-member + net: Net, + trainloader: torch.utils.data.DataLoader, + epochs: int, + device: torch.device, # pylint: disable=no-member ) -> None: """Train the network.""" # Define loss and optimizer @@ -110,9 +110,9 @@ def train( def test( - net: Net, - testloader: torch.utils.data.DataLoader, - device: torch.device, # pylint: disable=no-member + net: Net, + testloader: torch.utils.data.DataLoader, + device: torch.device, # pylint: disable=no-member ) -> Tuple[float, float]: """Validate the network on the entire test set.""" # Define loss and metrics diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index df4da7c11cff..61c7e7f762b3 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -24,10 +24,10 @@ class CifarClient(fl.client.NumPyClient): """Flower client implementing CIFAR-10 image classification using PyTorch.""" def __init__( - self, - model: cifar.Net, - trainloader: DataLoader, - testloader: DataLoader, + self, + model: cifar.Net, + trainloader: DataLoader, + testloader: DataLoader, ) -> None: self.model = model self.trainloader = trainloader @@ -61,7 +61,7 @@ def set_parameters(self, parameters: List[np.ndarray]) -> None: self.model.load_state_dict(state_dict, strict=True) def fit( - self, parameters: List[np.ndarray], config: Dict[str, str] + self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: # Set model parameters, train model, return updated model parameters self.set_parameters(parameters) @@ -69,7 +69,7 @@ def fit( return self.get_parameters(config={}), len(self.trainloader.dataset), {} def evaluate( - self, parameters: List[np.ndarray], config: Dict[str, str] + self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[float, int, Dict]: # Set model parameters, evaluate model on local test dataset, return result self.set_parameters(parameters) diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index 8e07494b6492..1dabd5732b9b 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -10,6 +10,7 @@ disable_progress_bar() + class FlowerClient(fl.client.NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model @@ -55,7 +56,6 @@ def _set_parameters(model, parameters): def main() -> None: - parser = argparse.ArgumentParser(description="Flower") parser.add_argument( "--node-id", diff --git a/examples/quickstart-pytorch-lightning/mnist.py b/examples/quickstart-pytorch-lightning/mnist.py index d32a0afe2d1e..95342f4fb9b3 100644 --- a/examples/quickstart-pytorch-lightning/mnist.py +++ b/examples/quickstart-pytorch-lightning/mnist.py @@ -86,16 +86,20 @@ def load_data(partition): # 60 % for the federated train and 20 % for the federated validation (both in fit) partition_train_valid = partition_full["train"].train_test_split(train_size=0.75) trainloader = DataLoader( - partition_train_valid["train"], batch_size=32, - shuffle=True, collate_fn=collate_fn, num_workers=1 + partition_train_valid["train"], + batch_size=32, + shuffle=True, + collate_fn=collate_fn, + num_workers=1, ) valloader = DataLoader( - partition_train_valid["test"], batch_size=32, - collate_fn=collate_fn, num_workers=1 + partition_train_valid["test"], + batch_size=32, + collate_fn=collate_fn, + num_workers=1, ) testloader = DataLoader( - partition_full["test"], batch_size=32, - collate_fn=collate_fn, num_workers=1 + partition_full["test"], batch_size=32, collate_fn=collate_fn, num_workers=1 ) return trainloader, valloader, testloader diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py index 88f654d4398e..5dc0e88b3c75 100644 --- a/examples/quickstart-sklearn-tabular/client.py +++ b/examples/quickstart-sklearn-tabular/client.py @@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"test_accuracy": accuracy} # Start Flower client - fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client()) + fl.client.start_client( + server_address="0.0.0.0:8080", client=IrisClient().to_client() + )