Skip to content

Commit

Permalink
Format change
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Dec 20, 2024
1 parent 0f9d923 commit 82207a7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions examples/hello-world/hello-pt/fedavg_script_runner_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
train_script = "src/hello-pt_cifar10_fl.py"

job = FedAvgJob(
name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork(),
name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork()
)

# Add clients
for i in range(n_clients):
executor = ScriptRunner(
script=train_script, script_args="", launch_external_process=True # f"--batch_size 32 --data_path /tmp/data/site-{i}"
script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to(executor, f"site-{i + 1}")

Expand Down
3 changes: 1 addition & 2 deletions examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def main():
print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
global_step = input_model.current_round * steps + epoch * len(train_loader) + i
summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
running_loss = xyz

running_loss = 0.0

print("Finished Training")

Expand Down
3 changes: 3 additions & 0 deletions nvflare/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

logger = logging.getLogger(__name__)


class ClientAPIType(Enum):
IN_PROCESS_API = "IN_PROCESS_API"
EX_PROCESS_API = "EX_PROCESS_API"
Expand All @@ -38,6 +39,7 @@ class ClientAPIType(Enum):
client_api: Optional[APISpec] = None
data_bus = DataBus()


def death_watch():
"""
Python's main thread doesn't die if there are running thread pools.
Expand All @@ -54,6 +56,7 @@ def death_watch():
except Exception as ex:
logger.warning(f"Death watch failed with error: {ex}")


def init(rank: Optional[str] = None):
"""Initializes NVFlare Client API environment.
Expand Down

0 comments on commit 82207a7

Please sign in to comment.