Commit

Merge branch 'NVIDIA:main' into main
nvidianz authored Jul 12, 2024
2 parents 3fdb95b + 8ec804f commit 654bc17
Showing 46 changed files with 1,708 additions and 242 deletions.
@@ -1,6 +1,6 @@
{
"format_version": 2,
-"min_clients": 8,
+"num_clients": 8,
"num_rounds": 50,
"TRAIN_SPLIT_ROOT": "/tmp/cifar10_splits",
"alpha": 1.0,
@@ -15,7 +15,7 @@
"path": "pt.utils.cifar10_data_splitter.Cifar10DataSplitter",
"args": {
"split_dir": "{TRAIN_SPLIT_ROOT}",
-"num_sites": "{min_clients}",
+"num_sites": "{num_clients}",
"alpha": "{alpha}"
}
},
@@ -52,7 +52,7 @@
"id": "scaffold_ctl",
"name": "Scaffold",
"args": {
-"min_clients": "{min_clients}",
+"num_clients": "{num_clients}",
"num_rounds": "{num_rounds}",

"persistor_id": "persistor"

@@ -24,7 +24,7 @@
"id": "fedavg_newton_raphson",
"path": "newton_raphson_workflow.FedAvgNewtonRaphson",
"args": {
-"min_clients": 4,
+"num_clients": 4,
"num_rounds": 5,
"damping_factor": 0.8,
"persistor_id": "newton_raphson_persistor"

@@ -15,7 +15,7 @@
import os
import sys

-sys.path.insert(0, os.path.join(os.getcwd(), "..", "..", "..", "advanced", "cifar10"))
+sys.path.insert(0, os.path.join(os.getcwd(), "..", "..", "advanced", "cifar10"))

from pt.learners.cifar10_model_learner import CIFAR10ModelLearner
from pt.networks.cifar10_nets import ModerateCNN
28 changes: 28 additions & 0 deletions examples/hello-world/hello-fedavg/README.md
@@ -0,0 +1,28 @@
# Hello FedAvg

In this example we highlight the flexibility of the ModelController API, and show how to write a Federated Averaging workflow with early stopping, model selection, and saving and loading. Follow along in the [hello-fedavg.ipynb](hello-fedavg.ipynb) notebook for more details.

### 1. Setup

```
pip install nvflare~=2.5.0rc torch torchvision tensorboard
```

### 2. PTFedAvgEarlyStopping using ModelController API

The ModelController API makes it easy to customize a workflow in Python code.

- FedAvg: We subclass the BaseFedAvg class to leverage the predefined aggregation functions.
- Early Stopping: We add a `stop_condition` argument (e.g. `"accuracy >= 80"`) and end the workflow early if the corresponding global model metric meets the condition.
- Model Selection: As an alternative to using an `IntimeModelSelector` component for model selection, we instead compare the metrics of the models in the workflow to select the best model each round.
- Saving/Loading: Rather than configuring a persistor component such as `PTFileModelPersistor`, we choose to utilize PyTorch's save and load functions and save the metadata of the FLModel separately.
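
The early stopping and model selection logic described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in `pt_fedavg_early_stopping_script.py`: the helper names (`parse_stop_condition`, `should_stop`, `select_best`, `run_rounds`) are hypothetical, metrics are assumed to be plain dictionaries of floats, and no NVFlare classes are used.

```python
import operator
from typing import Callable, Dict, List, Optional, Tuple

Metrics = Dict[str, float]
Compare = Callable[[float, float], bool]

# Comparison operators accepted in a stop condition string (assumed set).
_OPS: Dict[str, Compare] = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def parse_stop_condition(condition: str) -> Tuple[str, Compare, float]:
    """Split a string such as "accuracy >= 80" into (metric key, comparison, target)."""
    parts = condition.split()
    if len(parts) != 3 or parts[1] not in _OPS:
        raise ValueError(f"expected '<metric> <op> <value>', got {condition!r}")
    key, op, target = parts
    return key, _OPS[op], float(target)


def should_stop(metrics: Metrics, condition: Optional[str]) -> bool:
    """Return True when the named metric is present and satisfies the condition."""
    if not condition:
        return False
    key, compare, target = parse_stop_condition(condition)
    value = metrics.get(key)
    return value is not None and compare(value, target)


def select_best(best: Optional[Metrics], candidate: Metrics, key: str) -> Optional[Metrics]:
    """Keep whichever metrics dict has the higher value for `key` (higher is assumed better)."""
    if candidate.get(key) is None:
        return best
    if best is None or candidate[key] > best[key]:
        return candidate
    return best


def run_rounds(round_metrics: List[Metrics], condition: str, key: str) -> Tuple[int, Optional[Metrics]]:
    """Walk through per-round global metrics, tracking the best round and stopping early."""
    best: Optional[Metrics] = None
    rounds_run = 0
    for metrics in round_metrics:
        rounds_run += 1
        best = select_best(best, metrics, key)
        if should_stop(metrics, condition):
            break
    return rounds_run, best


if __name__ == "__main__":
    history = [{"accuracy": 55.0}, {"accuracy": 82.5}, {"accuracy": 81.0}]
    print(run_rounds(history, "accuracy >= 80", "accuracy"))
```

In a real workflow the per-round metrics would come from the aggregated global model rather than a fixed list, and the selected model's weights would be saved alongside the metrics.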

### 3. Run the script

Use the Job API to define and run the example with the simulator:

```
python3 pt_fedavg_early_stopping_script.py
```

View the results in the job workspace: `/tmp/nvflare/jobs/workdir`.