Skip to content

Commit

Permalink
refactor(framework) Update torch template (#4295)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Oct 8, 2024
1 parent 270f823 commit ef3646c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 15 deletions.
19 changes: 8 additions & 11 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
"""$project_name: A Flower / $framework_str app."""

import torch
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from $import_name.task import (
Net,
load_data,
get_weights,
set_weights,
train,
test,
)
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from $import_name.task import Net, get_weights, load_data, set_weights, test, train


# Define Flower Client and client_fn
Expand All @@ -32,7 +25,11 @@ class FlowerClient(NumPyClient):
self.local_epochs,
self.device,
)
return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}
return (
get_weights(self.net),
len(self.trainloader.dataset),
{"train_loss": train_loss},
)

def evaluate(self, parameters, config):
set_weights(self.net, parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from $import_name.task import Net, get_weights


Expand All @@ -27,5 +26,6 @@ def server_fn(context: Context):

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
6 changes: 3 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor


class Net(nn.Module):
Expand Down Expand Up @@ -67,7 +67,7 @@ def train(net, trainloader, epochs, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
net.train()
running_loss = 0.0
for _ in range(epochs):
Expand Down

0 comments on commit ef3646c

Please sign in to comment.