diff --git a/examples/getting_started/pt/swarm_script_executor_cifar10.py b/examples/getting_started/pt/swarm_script_executor_cifar10.py index 4a04f10f40..adf59091c5 100644 --- a/examples/getting_started/pt/swarm_script_executor_cifar10.py +++ b/examples/getting_started/pt/swarm_script_executor_cifar10.py @@ -47,19 +47,22 @@ executor = ScriptExecutor(task_script_path=train_script) job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"]) - client_controller = SwarmClientController() - job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) - - client_controller = CrossSiteEvalClientController() - job.to(client_controller, f"site-{i}", tasks=["cse_*"]) - # In swarm learning, each client acts also as an aggregator aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) - job.to(aggregator, f"site-{i}") # In swarm learning, each client uses a model persistor and shareable_generator - job.to(PTFileModelPersistor(model=Net()), f"site-{i}") - job.to(SimpleModelShareableGenerator(), f"site-{i}") + persistor = PTFileModelPersistor(model=Net()) + shareable_generator = SimpleModelShareableGenerator() + + client_controller = SwarmClientController( + aggregator_id=job.as_id(aggregator), + persistor_id=job.as_id(persistor), + shareable_generator_id=job.as_id(shareable_generator), + ) + job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) + + client_controller = CrossSiteEvalClientController() + job.to(client_controller, f"site-{i}", tasks=["cse_*"]) # job.export_job("/tmp/nvflare/jobs/job_config") job.simulator_run("/tmp/nvflare/jobs/workdir")