diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl index 4f2b26ceddea..7137a7791683 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -15,7 +15,7 @@ from $project_name.task import ( # Define Flower Client and client_fn class FlowerClient(NumPyClient): - def __init__(self, net, trainloader, valloader) -> None: + def __init__(self, net, trainloader, valloader): self.net = net self.trainloader = trainloader self.valloader = valloader @@ -31,7 +31,7 @@ class FlowerClient(NumPyClient): return loss, len(self.valloader.dataset), {"accuracy": accuracy} -def client_fn(cid: str): +def client_fn(cid): # Load model and data net = Net().to(DEVICE) trainloader, valloader = load_data(int(cid), 2) diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl index 82e57388fa3e..85460564b6ef 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -16,7 +16,7 @@ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class Net(nn.Module): """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" - def __init__(self) -> None: + def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) @@ -25,7 +25,7 @@ class Net(nn.Module): self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) @@ -34,7 +34,7 @@ class Net(nn.Module): return self.fc3(x) -def load_data(partition_id: int, num_partitions: int): +def load_data(partition_id, num_partitions): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id)