From 4b32f274e6993ae05a02d186c067fcdb0ba8f544 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 2 Aug 2024 15:09:24 -0400 Subject: [PATCH] Added id to the jobAPI swarm_script_executor_cifar10 component deploy (#2678) * Added id to the swarm_script_executor_cifar10 component deploy. * codestyle fix. * Changed to use job.as_id(). * codestyle fix. * changed to use job.as_id(shareable_generator) for shareable_generator_id. * removed the un-necessary job.to() calls. --------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Co-authored-by: Sean Yang --- .../pt/swarm_script_executor_cifar10.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/getting_started/pt/swarm_script_executor_cifar10.py b/examples/getting_started/pt/swarm_script_executor_cifar10.py index 4a04f10f40..adf59091c5 100644 --- a/examples/getting_started/pt/swarm_script_executor_cifar10.py +++ b/examples/getting_started/pt/swarm_script_executor_cifar10.py @@ -47,19 +47,22 @@ executor = ScriptExecutor(task_script_path=train_script) job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"]) - client_controller = SwarmClientController() - job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) - - client_controller = CrossSiteEvalClientController() - job.to(client_controller, f"site-{i}", tasks=["cse_*"]) - # In swarm learning, each client acts also as an aggregator aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) - job.to(aggregator, f"site-{i}") # In swarm learning, each client uses a model persistor and shareable_generator - job.to(PTFileModelPersistor(model=Net()), f"site-{i}") - job.to(SimpleModelShareableGenerator(), f"site-{i}") + persistor = PTFileModelPersistor(model=Net()) + shareable_generator = SimpleModelShareableGenerator() + + client_controller = SwarmClientController( + aggregator_id=job.as_id(aggregator), + persistor_id=job.as_id(persistor), + shareable_generator_id=job.as_id(shareable_generator), + ) + job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) + + client_controller = CrossSiteEvalClientController() + job.to(client_controller, f"site-{i}", tasks=["cse_*"]) # job.export_job("/tmp/nvflare/jobs/job_config") job.simulator_run("/tmp/nvflare/jobs/workdir")