Skip to content

Commit

Permalink
Merge branch 'main' into add-dp-secagg-demo-example
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Mar 15, 2024
2 parents 3ffe3f5 + d391fd6 commit 4e2da5e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from $project_name.task import (

# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, net, trainloader, valloader) -> None:
def __init__(self, net, trainloader, valloader):
self.net = net
self.trainloader = trainloader
self.valloader = valloader
Expand All @@ -31,7 +31,7 @@ class FlowerClient(NumPyClient):
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
def client_fn(cid):
# Load model and data
net = Net().to(DEVICE)
trainloader, valloader = load_data(int(cid), 2)
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
Expand All @@ -25,7 +25,7 @@ class Net(nn.Module):
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
Expand All @@ -34,7 +34,7 @@ class Net(nn.Module):
return self.fc3(x)


def load_data(partition_id: int, num_partitions: int):
def load_data(partition_id, num_partitions):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
partition = fds.load_partition(partition_id)
Expand Down

0 comments on commit 4e2da5e

Please sign in to comment.