diff --git a/docs/programming_guide/controllers/client_controlled_workflows.rst b/docs/programming_guide/controllers/client_controlled_workflows.rst index 823abfe997..ec40dbae5e 100644 --- a/docs/programming_guide/controllers/client_controlled_workflows.rst +++ b/docs/programming_guide/controllers/client_controlled_workflows.rst @@ -70,6 +70,7 @@ All client controlled workflows must have a server side controller that extends result_clients_policy: str = DefaultValuePolicy.ALL, max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT, progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, + private_p2p: bool = True, ): Init args for ServerSideController @@ -233,10 +234,11 @@ Cyclic Learning: Server Side Controller starting_client: str = "", max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT, progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, - rr_order: str = RROrder.FIXED, + private_p2p: bool = True, + cyclic_order: str = CyclicOrder.FIXED, ): -The only extra init arg is ``rr_order``, which specifies how the round-robin sequence is to be computed for each round: fixed order or random order. +The only extra init arg is ``cyclic_order``, which specifies how the cyclic sequence is to be computed for each round: fixed order or random order. Of all the init args, only the ``num_rounds`` must be explicitly specified. All others can take default values: diff --git a/examples/advanced/custom_authentication/security/server/custom/security_handler.py b/examples/advanced/custom_authentication/security/server/custom/security_handler.py index 4b5956ff67..c10ecb07ec 100644 --- a/examples/advanced/custom_authentication/security/server/custom/security_handler.py +++ b/examples/advanced/custom_authentication/security/server/custom/security_handler.py @@ -20,7 +20,7 @@ class ServerCustomSecurityHandler(FLComponent): def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.CLIENT_REGISTERED: + if event_type == EventType.CLIENT_REGISTER_RECEIVED: self.authenticate(fl_ctx=fl_ctx) def authenticate(self, fl_ctx: FLContext): diff --git a/examples/advanced/experiment-tracking/mlflow/experiment_tracking.ipynb b/examples/advanced/experiment-tracking/mlflow/experiment_tracking.ipynb index cfec9dc778..0a92d40c2d 100644 --- a/examples/advanced/experiment-tracking/mlflow/experiment_tracking.ipynb +++ b/examples/advanced/experiment-tracking/mlflow/experiment_tracking.ipynb @@ -22,7 +22,7 @@ "metadata": {}, "source": [ "\n", - "Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the streaming capability from the clients to the server with Tensorboard SummaryWriter sender syntax, but with a MLflow receiver\n", + "Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. 
This example also highlights the streaming capability from the clients to the server with Tensorboard SummaryWriter sender syntax, but with a MLflow receiver\n", "\n", "> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n" ] diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf index 00dbc2df6c..9bc187c8ab 100644 --- a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf +++ b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf @@ -55,7 +55,7 @@ "mlflow.note.content": "## **Hello PyTorch experiment with MLflow**" }, "run_tags": { - "mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n" + "mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n" } }, "artifact_location": "artifacts" diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.conf b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.conf index 1bce0828c9..e9f94a9b82 100644 --- a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.conf +++ b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.conf @@ -56,7 +56,7 @@ "mlflow.note.content": "## **Hello PyTorch experiment with MLflow**" }, "run_tags": { - "mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. 
This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n" + "mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n" } }, "artifact_location": "artifacts" diff --git a/examples/advanced/experiment-tracking/wandb/jobs/hello-pt-wandb/app/config/config_fed_server.json b/examples/advanced/experiment-tracking/wandb/jobs/hello-pt-wandb/app/config/config_fed_server.json index 1aebf9d748..dc170f82ae 100644 --- a/examples/advanced/experiment-tracking/wandb/jobs/hello-pt-wandb/app/config/config_fed_server.json +++ b/examples/advanced/experiment-tracking/wandb/jobs/hello-pt-wandb/app/config/config_fed_server.json @@ -45,7 +45,7 @@ "kwargs" : { "project": "hello-pt-experiment", "name": "hello-pt", - "notes": "Federated Experiment tracking with W&B \n Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server and deliver to MLFLow.\\n\\n> **_NOTE:_** \\n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n", + "notes": "Federated Experiment tracking with W&B \n Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. 
This example also highlights the Flare streaming capability from the clients to the server and deliver to MLFLow.\\n\\n> **_NOTE:_** \\n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n", "tags": ["baseline", "paper1"], "job_type": "train-validate", "config": { diff --git a/nvflare/apis/controller_spec.py b/nvflare/apis/controller_spec.py index 2d9f365341..f0019da206 100644 --- a/nvflare/apis/controller_spec.py +++ b/nvflare/apis/controller_spec.py @@ -293,7 +293,6 @@ def stop_controller(self, fl_ctx: FLContext): """ pass - @abstractmethod def process_result_of_unknown_task( self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext ): diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 4c65843e70..7ecac17486 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -34,6 +34,7 @@ class EventType(object): JOB_COMPLETED = "_job_completed" JOB_ABORTED = "_job_aborted" JOB_CANCELLED = "_job_cancelled" + JOB_DEAD = "_job_dead" BEFORE_PULL_TASK = "_before_pull_task" AFTER_PULL_TASK = "_after_pull_task" @@ -50,6 +51,8 @@ class EventType(object): AFTER_TASK_EXECUTION = "_after_task_execution" BEFORE_SEND_TASK_RESULT = "_before_send_task_result" AFTER_SEND_TASK_RESULT = "_after_send_task_result" + BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK = "_before_process_result_of_unknown_task" + AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK = "_after_process_result_of_unknown_task" CRITICAL_LOG_AVAILABLE = "_critical_log_available" ERROR_LOG_AVAILABLE = "_error_log_available" @@ -73,8 +76,14 @@ class EventType(object): BEFORE_CLIENT_REGISTER = "_before_client_register" AFTER_CLIENT_REGISTER = "_after_client_register" - CLIENT_REGISTERED = "_client_registered" + CLIENT_REGISTER_RECEIVED = "_client_register_received" + CLIENT_REGISTER_PROCESSED = "_client_register_processed" + CLIENT_QUIT = "_client_quit" SYSTEM_BOOTSTRAP = "_system_bootstrap" + BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat" + AFTER_CLIENT_HEARTBEAT = "_after_client_heartbeat" + CLIENT_HEARTBEAT_RECEIVED = "_client_heartbeat_received" + CLIENT_HEARTBEAT_PROCESSED = "_client_heartbeat_processed" AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index af15ac6504..58ce869f75 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -150,12 +150,14 @@ class FLContextKey(object): COMMUNICATION_ERROR = "Flare_communication_error__" UNAUTHENTICATED = "Flare_unauthenticated__" CLIENT_RESOURCE_SPECS = "__client_resource_specs" + RESOURCE_CHECK_RESULT = "__resource_check_result" JOB_PARTICIPANTS = "__job_participants" JOB_BLOCK_REASON = "__job_block_reason" # why the job should be blocked from scheduling SSID = "__ssid__" CLIENT_TOKEN = "__client_token" AUTHORIZATION_RESULT = "_authorization_result" AUTHORIZATION_REASON = "_authorization_reason" + DEAD_JOB_CLIENT_NAME = "_dead_job_client_name" CLIENT_REGISTER_DATA = "_client_register_data" SECURITY_ITEMS = "_security_items" diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 92d190805b..6091ac23f1 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -42,6 +42,7 @@ def initialize(self, fl_ctx: FLContext): return self._engine = engine + self.start_controller(fl_ctx) def set_communicator(self, communicator: WFCommSpec): if not 
isinstance(communicator, WFCommSpec): diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index 9ecbc05feb..06b5d13457 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -18,6 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus +from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext @@ -133,8 +134,8 @@ def _set_stats(self, fl_ctx: FLContext): raise TypeError( "collector must be an instance of GroupInfoCollector, but got {}".format(type(collector)) ) - collector.set_info( - group_name=self._name, + collector.add_info( + group_name=self.controller._name, info={ "tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks}, }, @@ -149,6 +150,12 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): """ if event_type == InfoCollector.EVENT_TYPE_GET_STATS: self._set_stats(fl_ctx) + elif event_type == EventType.JOB_DEAD: + client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) + with self._dead_clients_lock: + self.log_info(fl_ctx, f"received dead job report from client {client_name}") + if not self._dead_client_reports.get(client_name): + self._dead_client_reports[client_name] = time.time() def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: """Called by runner when a client asks for a task. @@ -330,22 +337,6 @@ def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None: self.cancel_task(task=task, fl_ctx=fl_ctx) self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name)) - def handle_dead_job(self, client_name: str, fl_ctx: FLContext): - """Called by the Engine to handle the case that the job on the client is dead. 
- - Args: - client_name: name of the client on which the job is dead - fl_ctx: the FLContext - - """ - # record the report and to be used by the task monitor - with self._dead_clients_lock: - self.log_info(fl_ctx, f"received dead job report from client {client_name}") - if not self._dead_client_reports.get(client_name): - self._dead_client_reports[client_name] = time.time() - - self.controller.handle_dead_job(client_name, fl_ctx) - def process_task_check(self, task_id: str, fl_ctx: FLContext): with self._task_lock: # task_id is the uuid associated with the client_task @@ -400,7 +391,14 @@ def _do_process_submission( if client_task is None: # cannot find a standing task for the submission self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id)) + + self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK") + self.fire_event(EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx) + self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) + + self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK") + self.fire_event(EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx) return task = client_task.task diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py index 1b1e8c14c4..004e623056 100644 --- a/nvflare/apis/server_engine_spec.py +++ b/nvflare/apis/server_engine_spec.py @@ -154,12 +154,13 @@ def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext): pass @abstractmethod - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): """To send the start client run commands to the clients Args: client_sites: client sites job_id: job_id + fl_ctx: FLContext Returns: @@ -187,7 +188,7 @@ def check_client_resources( @abstractmethod def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): """Cancels the request resources for the job. @@ -195,6 +196,7 @@ def cancel_client_resources( resource_check_results: A dict of {client_name: client_check_result} where client_check_result is a tuple of (is_resource_enough, resource reserve token if any) resource_reqs: A dict of {client_name: resource requirements dict} + fl_ctx: FLContext """ pass diff --git a/nvflare/apis/wf_comm_spec.py b/nvflare/apis/wf_comm_spec.py index 8350b16b7a..a32c34948a 100644 --- a/nvflare/apis/wf_comm_spec.py +++ b/nvflare/apis/wf_comm_spec.py @@ -293,16 +293,6 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext): """ raise NotImplementedError - def handle_dead_job(self, client_name: str, fl_ctx: FLContext): - """Called by the Engine to handle the case that the job on the client is dead. - - Args: - client_name: name of the client on which the job is dead - fl_ctx: the FLContext - - """ - raise NotImplementedError - def initialize_run(self, fl_ctx: FLContext): """Called when a new RUN is about to start. 
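
With ``handle_dead_job()`` removed from both ``WFCommSpec`` and ``WFCommServer``, dead-job handling is now event-driven. A minimal sketch (not part of this change; ``DeadJobMonitor`` is a hypothetical name) of a server-side component observing dead client jobs through the new ``EventType.JOB_DEAD`` event and the new ``FLContextKey.DEAD_JOB_CLIENT_NAME`` prop:

    from nvflare.apis.event_type import EventType
    from nvflare.apis.fl_component import FLComponent
    from nvflare.apis.fl_constant import FLContextKey
    from nvflare.apis.fl_context import FLContext


    class DeadJobMonitor(FLComponent):
        def __init__(self):
            super().__init__()
            self.dead_clients = []

        def handle_event(self, event_type: str, fl_ctx: FLContext):
            if event_type == EventType.JOB_DEAD:
                # the reporting client's name is carried in the FLContext
                client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
                self.dead_clients.append(client_name)
                self.log_info(fl_ctx, f"client {client_name} reported dead job")

``WFCommServer.handle_event`` and ``CyclicController.handle_event`` in this change follow the same pattern.
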
diff --git a/nvflare/app_common/ccwf/common.py b/nvflare/app_common/ccwf/common.py index 389123f057..e1883fbe24 100644 --- a/nvflare/app_common/ccwf/common.py +++ b/nvflare/app_common/ccwf/common.py @@ -88,6 +88,8 @@ class Constant: FINAL_RESULT_ACK_TIMEOUT = 10 GET_MODEL_TIMEOUT = 10 + PROP_KEY_TRAIN_CLIENTS = "cwf.train_clients" + class ModelType: diff --git a/nvflare/app_common/ccwf/cse_server_ctl.py b/nvflare/app_common/ccwf/cse_server_ctl.py index f95e5fdc05..e3b87cd183 100644 --- a/nvflare/app_common/ccwf/cse_server_ctl.py +++ b/nvflare/app_common/ccwf/cse_server_ctl.py @@ -204,7 +204,13 @@ def sub_flow(self, abort_signal: Signal, fl_ctx: FLContext): self.current_round += 1 # ask everyone to eval everyone else's local model + train_clients = fl_ctx.get_prop(Constant.PROP_KEY_TRAIN_CLIENTS) for c in self.evaluatees: + if train_clients and c not in train_clients: + # this client does not have local models + self.log_info(fl_ctx, f"ignore client {c} since it does not have local models") + continue + self._ask_to_evaluate( current_round=self.current_round, model_name=ModelName.BEST_MODEL, diff --git a/nvflare/app_common/ccwf/cyclic_client_ctl.py b/nvflare/app_common/ccwf/cyclic_client_ctl.py index 9018eb240d..f4cc1cb885 100644 --- a/nvflare/app_common/ccwf/cyclic_client_ctl.py +++ b/nvflare/app_common/ccwf/cyclic_client_ctl.py @@ -61,8 +61,8 @@ def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: clients = self.get_config_prop(Constant.CLIENTS) # make sure the starting client is the 1st rotate_to_front(self.me, clients) - rr_order = self.get_config_prop(Constant.ORDER) - self.log_info(fl_ctx, f"Starting cyclic workflow on clients {clients} with order {rr_order} ") + cyclic_order = self.get_config_prop(Constant.ORDER) + self.log_info(fl_ctx, f"Starting cyclic workflow on clients {clients} with order {cyclic_order} ") self._set_task_headers( task_data=shareable, num_rounds=self.get_config_prop(AppConstants.NUM_ROUNDS), diff --git a/nvflare/app_common/ccwf/cyclic_server_ctl.py b/nvflare/app_common/ccwf/cyclic_server_ctl.py index 1806cc7848..630b8d2aa0 100644 --- a/nvflare/app_common/ccwf/cyclic_server_ctl.py +++ b/nvflare/app_common/ccwf/cyclic_server_ctl.py @@ -57,7 +57,9 @@ def __init__( ) check_str("cyclic_order", cyclic_order) if cyclic_order not in [CyclicOrder.FIXED, CyclicOrder.RANDOM]: - raise ValueError(f"invalid rr_order {cyclic_order}: must be in {[CyclicOrder.FIXED, CyclicOrder.RANDOM]}") + raise ValueError( + f"invalid cyclic_order {cyclic_order}: must be in {[CyclicOrder.FIXED, CyclicOrder.RANDOM]}" + ) self.cyclic_order = cyclic_order def prepare_config(self): diff --git a/nvflare/app_common/ccwf/swarm_server_ctl.py b/nvflare/app_common/ccwf/swarm_server_ctl.py index ca2ffa3e62..9f64e2ad4e 100644 --- a/nvflare/app_common/ccwf/swarm_server_ctl.py +++ b/nvflare/app_common/ccwf/swarm_server_ctl.py @@ -92,5 +92,9 @@ def start_controller(self, fl_ctx: FLContext): if c not in self.train_clients and c not in self.aggr_clients: raise RuntimeError(f"Config Error: client {c} is neither train client nor aggr client") + # set train_clients as a sticky prop in fl_ctx + # in case CSE (cross site eval) workflow follows, it will know that only training clients have local models + fl_ctx.set_prop(key=Constant.PROP_KEY_TRAIN_CLIENTS, value=self.train_clients, private=True, sticky=True) + def prepare_config(self): return {Constant.AGGR_CLIENTS: self.aggr_clients, Constant.TRAIN_CLIENTS: self.train_clients} diff --git 
a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py index 2619b98d68..c7e03d394f 100644 --- a/nvflare/app_common/job_schedulers/job_scheduler.py +++ b/nvflare/app_common/job_schedulers/job_scheduler.py @@ -93,7 +93,7 @@ def _cancel_resources( if not isinstance(engine, ServerEngineSpec): raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.") - engine.cancel_client_resources(resource_check_results, resource_reqs) + engine.cancel_client_resources(resource_check_results, resource_reqs, fl_ctx) self.log_debug(fl_ctx, f"cancel client resources using check results: {resource_check_results}") return False, None @@ -165,6 +165,7 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp return SCHEDULE_RESULT_NO_RESOURCE, None, block_reason resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx) + fl_ctx.set_prop(FLContextKey.RESOURCE_CHECK_RESULT, resource_check_results, private=True, sticky=False) self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx) if not resource_check_results: diff --git a/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py b/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py index c3b6b08e56..95212ff73f 100644 --- a/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py +++ b/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py @@ -163,7 +163,7 @@ def pairwise_setup(self, ordered_sites: List[SiteSize]): bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.SETUP_MSG] for site_name in results} @@ -181,7 +181,7 @@ def pairwise_requests(self, ordered_sites: List[SiteSize], setup_msgs: Dict[str, bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.REQUEST_MSG] for site_name in results} @@ -199,7 +199,7 @@ def pairwise_responses(self, ordered_sites: List[SiteSize], request_msgs: Dict[s bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.RESPONSE_MSG] for site_name in results} @@ -217,7 +217,7 @@ def pairwise_intersect(self, ordered_sites: List[SiteSize], response_msg: Dict[s bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.ITEMS_SIZE] for site_name in results} @@ -279,7 +279,7 @@ def calculate_intersections(self, response_msg) -> Dict[str, int]: task_inputs[client_name] = inputs bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, 
abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) intersects = {client_name: results[client_name].data[PSIConst.ITEMS_SIZE] for client_name in results} @@ -292,7 +292,11 @@ def process_requests(self, s: SiteSize, request_msgs: Dict[str, str]) -> Dict[st task_inputs[PSIConst.REQUEST_MSG_SET] = request_msgs bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.broadcast_and_wait( - task_name=self.task_name, task_input=task_inputs, targets=[s.name], abort_signal=self.abort_signal + task_name=self.task_name, + task_input=task_inputs, + fl_ctx=self.fl_ctx, + targets=[s.name], + abort_signal=self.abort_signal, ) dxo = results[s.name] @@ -309,7 +313,7 @@ def create_requests(self, site_setup_msgs) -> Dict[str, str]: bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) request_msgs = {client_name: results[client_name].data[PSIConst.REQUEST_MSG] for client_name in results} return request_msgs @@ -335,6 +339,7 @@ def prepare_sites(self, abort_signal): results = bop.broadcast_and_wait( task_name=self.task_name, task_input=inputs, + fl_ctx=self.fl_ctx, targets=targets, min_responses=min_responses, abort_signal=abort_signal, @@ -352,7 +357,11 @@ def prepare_setup_messages(self, s: SiteSize, other_site_sizes: Set[int]) -> Dic inputs[PSIConst.ITEMS_SIZE_SET] = other_site_sizes bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.broadcast_and_wait( - task_name=self.task_name, task_input=inputs, targets=[s.name], abort_signal=self.abort_signal + task_name=self.task_name, + task_input=inputs, + fl_ctx=self.fl_ctx, + targets=[s.name], + abort_signal=self.abort_signal, ) dxo = results[s.name] return dxo.data[PSIConst.SETUP_MSG] diff --git a/nvflare/app_common/workflows/broadcast_operator.py b/nvflare/app_common/workflows/broadcast_operator.py index ce548cd95e..e6473426ad 100644 --- a/nvflare/app_common/workflows/broadcast_operator.py +++ b/nvflare/app_common/workflows/broadcast_operator.py @@ -13,6 +13,7 @@ # limitations under the License. 
import threading +import time from typing import Dict, List, Optional, Union from nvflare.apis.client import Client @@ -41,29 +42,35 @@ def broadcast_and_wait( self, task_name: str, task_input: Shareable, + fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None, task_props: Optional[Dict] = None, min_responses: int = 1, abort_signal: Signal = None, ) -> Dict[str, DXO]: task = Task(name=task_name, data=task_input, result_received_cb=self.results_cb, props=task_props) - self.controller.broadcast_and_wait(task, self.fl_ctx, targets, min_responses, 0, abort_signal) + self.controller.broadcast_and_wait(task, fl_ctx, targets, min_responses, 0, abort_signal) return self.results def multicasts_and_wait( self, task_name: str, task_inputs: Dict[str, Shareable], + fl_ctx: FLContext, abort_signal: Signal = None, + task_check_period: int = 0.5, ) -> Dict[str, DXO]: tasks: Dict[str, Task] = self.get_tasks(task_name, task_inputs) for client_name in tasks: - self.controller.broadcast(task=tasks[client_name], fl_ctx=self.fl_ctx, targets=[client_name]) + self.controller.send(task=tasks[client_name], fl_ctx=fl_ctx, targets=[client_name]) - for client_name in tasks: - self.log_info(self.fl_ctx, f"wait for client {client_name} task") - self.controller.wait_for_task(tasks[client_name], abort_signal) + while self.controller.get_num_standing_tasks(): + if abort_signal.triggered: + self.log_info(fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.") + return + self.log_debug(fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.") + time.sleep(task_check_period) return self.results diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index 65f8cb9195..22ca0e6b70 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -17,7 +17,8 @@ from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, Task -from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable @@ -264,11 +265,11 @@ def restore(self, state_data: dict, fl_ctx: FLContext): finally: pass - def handle_dead_job(self, client_name: str, fl_ctx: FLContext): - super().handle_dead_job(client_name, fl_ctx) - - new_client_list = [] - for client in self._participating_clients: - if client_name != client.name: - new_client_list.append(client) - self._participating_clients = new_client_list + def handle_event(self, event_type, fl_ctx): + if event_type == EventType.JOB_DEAD: + client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) + new_client_list = [] + for client in self._participating_clients: + if client_name != client.name: + new_client_list.append(client) + self._participating_clients = new_client_list diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index cd6b565241..05e2e7b739 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -144,7 +144,7 @@ def start_controller(self, fl_ctx: FLContext) -> None: self.event(AppEventType.INITIAL_MODEL_LOADED) self.engine = self.fl_ctx.get_engine() - self.initialize() + FLComponentWrapper.initialize(self) def _build_shareable(self, data: FLModel = None) -> Shareable: if not 
data:  # if no data is given, send self.model
diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py
new file mode 100644
index 0000000000..aaf610d88c
--- /dev/null
+++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+
+
+class CCAuthorizer(ABC):
+    @abstractmethod
+    def get_namespace(self) -> str:
+        """Returns the namespace of the CCAuthorizer.
+
+        Returns: namespace string
+
+        """
+        pass
+
+    @abstractmethod
+    def generate(self) -> str:
+        """Generates and returns the active CCAuthorizer token.
+
+        Returns: token string
+
+        """
+        pass
+
+    @abstractmethod
+    def verify(self, token: str) -> bool:
+        """Returns the token verification result.
+
+        Args:
+            token: the token string to be verified
+
+        Returns: True if the token is valid, False otherwise
+
+        """
+        pass
+
+
+class CCTokenGenerateError(Exception):
+    """Raised when CC token generation fails"""
+
+    pass
+
+
+class CCTokenVerifyError(Exception):
+    """Raised when CC token verification fails"""
+
+    pass
diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py
index 82fa7dba8a..052a61f5b4 100644
--- a/nvflare/app_opt/confidential_computing/cc_manager.py
+++ b/nvflare/app_opt/confidential_computing/cc_manager.py
@@ -11,95 +11,90 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import threading
+import time
+from typing import Dict, List, Tuple
 
 from nvflare.apis.event_type import EventType
 from nvflare.apis.fl_component import FLComponent
-from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey
+from nvflare.apis.fl_constant import FLContextKey, RunProcessKey
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.fl_exception import UnsafeComponentError
-
-from .cc_helper import CCHelper
+from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer, CCTokenGenerateError, CCTokenVerifyError
+from nvflare.fuel.hci.conn import Connection
+from nvflare.private.fed.server.training_cmds import TrainingCommandModule
 
 PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token"
 CC_TOKEN = "_cc_token"
+CC_ISSUER = "_cc_issuer"
+CC_NAMESPACE = "_cc_namespace"
 CC_INFO = "_cc_info"
 CC_TOKEN_VALIDATED = "_cc_token_validated"
+CC_VERIFY_ERROR = "_cc_verify_error."
+
+CC_ISSUER_ID = "issuer_id"
+TOKEN_GENERATION_TIME = "token_generation_time"
+TOKEN_EXPIRATION = "token_expiration"
+
+SHUTDOWN_SYSTEM = 1
+SHUTDOWN_JOB = 2
+
+CC_VERIFICATION_FAILED = "not meeting CC requirements"
 
 
 class CCManager(FLComponent):
-    def __init__(self, verifiers: list):
+    def __init__(
+        self,
+        cc_issuers_conf: List[Dict[str, str]],
+        cc_verifier_ids: List[str],
+        verify_frequency=600,
+        critical_level=SHUTDOWN_JOB,
+    ):
         """Manage all confidential computing related tasks.
 
         This manager does the following tasks:
-            obtaining its own GPU CC token
+            obtaining its own CC token
             preparing the token to the server
             keeping clients' tokens in server
             validating all tokens in the entire NVFlare system
+            not allowing the system to start if it failed to get a CC token
+            shutting down the running jobs if CC tokens expired
 
         Args:
-            verifiers (list):
-                each element in this list is a dictionary and the keys of dictionary are
-                "devices", "env", "url", "appraisal_policy_file" and "result_policy_file."
-
-                the values of devices are "gpu" and "cpu"
-                the values of env are "local" and "test"
-                currently, valid combination is gpu + local
-
-                url must be an empty string
-                appraisal_policy_file must point to an existing file
-                currently supports an empty file only
-
-                result_policy_file must point to an existing file
-                currently supports the following content only
-
-                .. code-block:: json
-
-                    {
-                        "version":"1.0",
-                        "authorization-rules":{
-                            "x-nv-gpu-available":true,
-                            "x-nv-gpu-attestation-report-available":true,
-                            "x-nv-gpu-info-fetched":true,
-                            "x-nv-gpu-arch-check":true,
-                            "x-nv-gpu-root-cert-available":true,
-                            "x-nv-gpu-cert-chain-verified":true,
-                            "x-nv-gpu-ocsp-cert-chain-verified":true,
-                            "x-nv-gpu-ocsp-signature-verified":true,
-                            "x-nv-gpu-cert-ocsp-nonce-match":true,
-                            "x-nv-gpu-cert-check-complete":true,
-                            "x-nv-gpu-measurement-available":true,
-                            "x-nv-gpu-attestation-report-parsed":true,
-                            "x-nv-gpu-nonce-match":true,
-                            "x-nv-gpu-attestation-report-driver-version-match":true,
-                            "x-nv-gpu-attestation-report-vbios-version-match":true,
-                            "x-nv-gpu-attestation-report-verified":true,
-                            "x-nv-gpu-driver-rim-schema-fetched":true,
-                            "x-nv-gpu-driver-rim-schema-validated":true,
-                            "x-nv-gpu-driver-rim-cert-extracted":true,
-                            "x-nv-gpu-driver-rim-signature-verified":true,
-                            "x-nv-gpu-driver-rim-driver-measurements-available":true,
-                            "x-nv-gpu-driver-vbios-rim-fetched":true,
-                            "x-nv-gpu-vbios-rim-schema-validated":true,
-                            "x-nv-gpu-vbios-rim-cert-extracted":true,
-                            "x-nv-gpu-vbios-rim-signature-verified":true,
-                            "x-nv-gpu-vbios-rim-driver-measurements-available":true,
-                            "x-nv-gpu-vbios-index-conflict":true,
-                            "x-nv-gpu-measurements-match":true
-                        }
-                    }
+            cc_issuers_conf: configuration of the CC token issuers. each element contains the CC token issuer
+                component ID and the token expiration time
+            cc_verifier_ids: component IDs of the CC token verifiers
+            verify_frequency: how often (in seconds) to verify the CC tokens of running jobs
+            critical_level: action to take when CC verification fails: SHUTDOWN_JOB or SHUTDOWN_SYSTEM
 
         """
         FLComponent.__init__(self)
         self.site_name = None
-        self.helper = None
-        self.verifiers = verifiers
-        self.my_token = None
+        self.cc_issuers_conf = cc_issuers_conf
+        self.cc_verifier_ids = cc_verifier_ids
+
+        if not isinstance(verify_frequency, int):
+            raise ValueError(f"verify_frequency must be an int, but got {verify_frequency.__class__}")
+        self.verify_frequency = int(verify_frequency)
+
+        self.critical_level = critical_level
+        if self.critical_level not in [SHUTDOWN_SYSTEM, SHUTDOWN_JOB]:
+            raise ValueError(f"critical_level must be in [{SHUTDOWN_SYSTEM}, {SHUTDOWN_JOB}]. 
But got {critical_level}") + + self.verify_time = None + self.cc_issuers = {} + self.cc_verifiers = {} self.participant_cc_info = {} # used by the Server to keep tokens of all clients + self.token_submitted = False + self.lock = threading.Lock() + def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.SYSTEM_BOOTSTRAP: try: - err = self._prepare_for_attestation(fl_ctx) + self._setup_cc_authorizers(fl_ctx) + + err = self._generate_tokens(fl_ctx) except: self.log_exception(fl_ctx, "exception in attestation preparation") err = "exception in attestation preparation" @@ -107,26 +102,27 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if err: self.log_critical(fl_ctx, err, fire_event=False) raise UnsafeComponentError(err) - elif event_type == EventType.BEFORE_CLIENT_REGISTER: + elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side - self._prepare_token_for_login(fl_ctx) - elif event_type == EventType.CLIENT_REGISTERED: + self._prepare_cc_info(fl_ctx) + elif event_type == EventType.CLIENT_REGISTER_RECEIVED or event_type == EventType.CLIENT_HEARTBEAT_RECEIVED: # Server side self._add_client_token(fl_ctx) - elif event_type == EventType.AUTHORIZE_COMMAND_CHECK: - command_to_check = fl_ctx.get_prop(key=FLContextKey.COMMAND_NAME) - self.logger.debug(f"Received {command_to_check=}") - if command_to_check == AdminCommandNames.CHECK_RESOURCES: - try: - err = self._client_to_check_participant_token(fl_ctx) - except: - self.log_exception(fl_ctx, "exception in validating participants") - err = "Participants unable to meet client CC requirements" - finally: - if err: - self._not_authorize_job(err, fl_ctx) - elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: + elif event_type == EventType.CLIENT_QUIT: # Server side + self._remove_client_token(fl_ctx) + elif event_type == EventType.BEFORE_CHECK_RESOURCE_MANAGER: + # Client side: check resources before job scheduled + try: + err = self._client_to_check_participant_token(fl_ctx) + except: + self.log_exception(fl_ctx, "exception in validating participants") + err = "Participants unable to meet client CC requirements" + finally: + if err: + self._block_job(err, fl_ctx) + elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: + # Server side: job scheduler check client resources try: err = self._server_to_check_client_token(fl_ctx) except: @@ -134,35 +130,126 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): err = "Clients unable to meet server CC requirements" finally: if err: - self._block_job(err, fl_ctx) + if self.critical_level == SHUTDOWN_JOB: + self._block_job(err, fl_ctx) + else: + threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start() elif event_type == EventType.AFTER_CHECK_CLIENT_RESOURCES: - # Server side - fl_ctx.remove_prop(PEER_CTX_CC_TOKEN) + client_resource_result = fl_ctx.get_prop(FLContextKey.RESOURCE_CHECK_RESULT) + if client_resource_result: + for site_name, check_result in client_resource_result.items(): + is_resource_enough, reason = check_result + if ( + not is_resource_enough + and reason.startswith(CC_VERIFY_ERROR) + and self.critical_level == SHUTDOWN_SYSTEM + ): + threading.Thread(target=self._shutdown_system, args=[reason, fl_ctx]).start() + break + + def _setup_cc_authorizers(self, fl_ctx): + engine = fl_ctx.get_engine() + for conf in self.cc_issuers_conf: + issuer_id = conf.get(CC_ISSUER_ID) + expiration = conf.get(TOKEN_EXPIRATION) + issuer = engine.get_component(issuer_id) + 
if not isinstance(issuer, CCAuthorizer): + raise RuntimeError(f"cc_issuer_id {issuer_id} must be a CCAuthorizer, but got {issuer.__class__}") + self.cc_issuers[issuer] = expiration + + for v_id in self.cc_verifier_ids: + verifier = engine.get_component(v_id) + if not isinstance(verifier, CCAuthorizer): + raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {verifier.__class__}") + namespace = verifier.get_namespace() + if namespace in self.cc_verifiers.keys(): + raise RuntimeError(f"Authorizer with namespace: {namespace} already exist.") + self.cc_verifiers[namespace] = verifier - def _prepare_token_for_login(self, fl_ctx: FLContext): - # client side - if self.my_token is None: - self.my_token = self.helper.get_token() - cc_info = {CC_TOKEN: self.my_token} - fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) + def _prepare_cc_info(self, fl_ctx: FLContext): + # client side: if token expired then generate a new one + self._handle_expired_tokens() + + if not self.token_submitted: + site_cc_info = self.participant_cc_info[self.site_name] + cc_info = self._get_participant_tokens(site_cc_info) + fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) + self.logger.info("Sent the CC-tokens to server.") + self.token_submitted = True def _add_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() peer_cc_info = peer_ctx.get_prop(CC_INFO) - self.participant_cc_info[token_owner] = peer_cc_info - self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False - def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: + if peer_cc_info: + self.participant_cc_info[token_owner] = peer_cc_info + self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}") + + if not self.verify_time or time.time() - self.verify_time > self.verify_frequency: + self._verify_running_jobs(fl_ctx) + + def _verify_running_jobs(self, fl_ctx): + engine = fl_ctx.get_engine() + run_processes = engine.run_processes + running_jobs = list(run_processes.keys()) + with self.lock: + for job_id in running_jobs: + job_participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS) + participants = [] + for _, client in job_participants.items(): + participants.append(client.name) + + err, participant_tokens = self._verify_participants(participants) + if err: + if self.critical_level == SHUTDOWN_JOB: + # maybe shutdown the whole system here. 
leave the user to define the action + engine.job_runner.stop_run(job_id, fl_ctx) + self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") + else: + threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start() + + self.verify_time = time.time() + + def _remove_client_token(self, fl_ctx: FLContext): + # server side + peer_ctx = fl_ctx.get_peer_context() + token_owner = peer_ctx.get_identity_name() + if token_owner in self.participant_cc_info.keys(): + self.participant_cc_info.pop(token_owner) + self.logger.info(f"Removed CC client: {token_owner}") + + def _generate_tokens(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() - self.helper = CCHelper(site_name=self.site_name, verifiers=self.verifiers) - ok = self.helper.prepare() - if not ok: - return "failed to attest" - self.my_token = self.helper.get_token() - self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_TOKEN_VALIDATED: True} + workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() + + self.participant_cc_info[self.site_name] = [] + for issuer, expiration in self.cc_issuers.items(): + try: + my_token = issuer.generate() + namespace = issuer.get_namespace() + + if not isinstance(expiration, int): + raise ValueError(f"token_expiration value must be int, but got {expiration.__class__}") + if not my_token: + return f"{issuer} failed to get CC token" + + self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}") + cc_info = { + CC_TOKEN: my_token, + CC_ISSUER: issuer, + CC_NAMESPACE: namespace, + TOKEN_GENERATION_TIME: time.time(), + TOKEN_EXPIRATION: int(expiration), + CC_TOKEN_VALIDATED: True, + } + self.participant_cc_info[self.site_name].append(cc_info) + self.token_submitted = False + except CCTokenGenerateError: + raise RuntimeError(f"{issuer} failed to generate CC token.") + return "" def _client_to_check_participant_token(self, fl_ctx: FLContext) -> str: @@ -192,47 +279,107 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if not isinstance(participants, list): return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list bot got {type(participants)}" - participant_tokens = {self.site_name: self.my_token} + err, participant_tokens = self._verify_participants(participants) + if err: + return err + + fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=False, private=False) + self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") + return "" + + def _verify_participants(self, participants): + # if server token expired, then generates a new one + self._handle_expired_tokens() + + participant_tokens = {} + site_cc_info = self.participant_cc_info[self.site_name] + participant_tokens[self.site_name] = self._get_participant_tokens(site_cc_info) + for p in participants: assert isinstance(p, str) if p == self.site_name: continue - if p not in self.participant_cc_info: - return f"no token available for participant {p}" - participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN] + # if p not in self.participant_cc_info: + # return f"no token available for participant {p}" + if self.participant_cc_info.get(p): + participant_tokens[p] = self._get_participant_tokens(self.participant_cc_info[p]) + else: + participant_tokens[p] = [{CC_TOKEN: "", CC_NAMESPACE: ""}] + return self._validate_participants_tokens(participant_tokens), participant_tokens - err = 
self._validate_participants_tokens(participant_tokens)
-        if err:
-            return err
 
+    def _get_participant_tokens(self, site_cc_info):
+        cc_info = []
+        for i in site_cc_info:
+            namespace = i.get(CC_NAMESPACE)
+            token = i.get(CC_TOKEN)
+            cc_info.append({CC_TOKEN: token, CC_NAMESPACE: namespace, CC_TOKEN_VALIDATED: False})
+        return cc_info
 
-        for p in participant_tokens:
-            self.participant_cc_info[p][CC_TOKEN_VALIDATED] = True
-        fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False)
-        self.logger.debug(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}")
-        return ""
+    def _handle_expired_tokens(self):
+        site_cc_info = self.participant_cc_info[self.site_name]
+        for i in site_cc_info:
+            issuer = i.get(CC_ISSUER)
+            token_generate_time = i.get(TOKEN_GENERATION_TIME)
+            expiration = i.get(TOKEN_EXPIRATION)
+            if time.time() - token_generate_time > expiration:
+                token = issuer.generate()
+                i[CC_TOKEN] = token
+                i[TOKEN_GENERATION_TIME] = time.time()
+                self.logger.info(
+                    f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}"
+                )
+
+        self.token_submitted = False
 
     def _validate_participants_tokens(self, participants) -> str:
         self.logger.debug(f"Validating participant tokens {participants=}")
-        result = self.helper.validate_participants(participants)
-        assert isinstance(result, dict)
-        for p in result:
-            self.participant_cc_info[p] = {CC_TOKEN: participants[p], CC_TOKEN_VALIDATED: True}
-        invalid_participant_list = [k for k, v in self.participant_cc_info.items() if v[CC_TOKEN_VALIDATED] is False]
+        result, invalid_participant_list = self._validate_participants(participants)
         if invalid_participant_list:
             invalid_participant_string = ",".join(invalid_participant_list)
             self.logger.debug(f"{invalid_participant_list=}")
-            return f"Participant {invalid_participant_string} not meeting CC requirements"
+            return f"Participant {invalid_participant_string} " + CC_VERIFICATION_FAILED
         else:
             return ""
 
-    def _not_authorize_job(self, reason: str, fl_ctx: FLContext):
-        job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "")
-        self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}")
-        fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_REASON, value=reason)
-        fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False)
+    def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) -> Tuple[Dict[str, bool], List[str]]:
+        result = {}
+        invalid_participant_list = []
+        if not participants:
+            return result, invalid_participant_list
+        for k, cc_info in participants.items():
+            for v in cc_info:
+                token = v.get(CC_TOKEN, "")
+                namespace = v.get(CC_NAMESPACE, "")
+                verifier = self.cc_verifiers.get(namespace, None)
+                try:
+                    if verifier and verifier.verify(token):
+                        result[k + "." 
+ namespace] = True + else: + invalid_participant_list.append(k + " namespace: {" + namespace + "}") + except CCTokenVerifyError: + invalid_participant_list.append(k + " namespace: {" + namespace + "}") + self.logger.info(f"CC - results from validating participants' tokens: {result}") + return result, invalid_participant_list def _block_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason) - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False) + fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=CC_VERIFY_ERROR + reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) + + def _shutdown_system(self, reason: str, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + run_processes = engine.run_processes + running_jobs = list(run_processes.keys()) + for job_id in running_jobs: + engine.job_runner.stop_run(job_id, fl_ctx) + + conn = Connection({}, engine.server.admin_server) + conn.app_ctx = engine + + cmd = TrainingCommandModule() + args = ["shutdown", "all"] + cmd.validate_command_targets(conn, args[1:]) + cmd.shutdown(conn, args) + + self.logger.error(f"CC system shutdown! due to reason: {reason}") diff --git a/nvflare/app_opt/confidential_computing/gpu_authorizer.py b/nvflare/app_opt/confidential_computing/gpu_authorizer.py new file mode 100644 index 0000000000..bd55e2a463 --- /dev/null +++ b/nvflare/app_opt/confidential_computing/gpu_authorizer.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer + +GPU_NAMESPACE = "x-nv-gpu-" + + +class GPUAuthorizer(CCAuthorizer): + """Note: This is just a fake implementation for GPU authorizer. It will be replaced later + with the real implementation. + + """ + + def __init__(self, verifiers: list) -> None: + """ + + Args: + verifiers (list): + each element in this list is a dictionary and the keys of dictionary are + "devices", "env", "url", "appraisal_policy_file" and "result_policy_file." + + the values of devices are "gpu" and "cpu" + the values of env are "local" and "test" + currently, valid combination is gpu + local + + url must be an empty string + appraisal_policy_file must point to an existing file + currently supports an empty file only + + result_policy_file must point to an existing file + currently supports the following content only + + .. 
code-block:: json + + { + "version":"1.0", + "authorization-rules":{ + "x-nv-gpu-available":true, + "x-nv-gpu-attestation-report-available":true, + "x-nv-gpu-info-fetched":true, + "x-nv-gpu-arch-check":true, + "x-nv-gpu-root-cert-available":true, + "x-nv-gpu-cert-chain-verified":true, + "x-nv-gpu-ocsp-cert-chain-verified":true, + "x-nv-gpu-ocsp-signature-verified":true, + "x-nv-gpu-cert-ocsp-nonce-match":true, + "x-nv-gpu-cert-check-complete":true, + "x-nv-gpu-measurement-available":true, + "x-nv-gpu-attestation-report-parsed":true, + "x-nv-gpu-nonce-match":true, + "x-nv-gpu-attestation-report-driver-version-match":true, + "x-nv-gpu-attestation-report-vbios-version-match":true, + "x-nv-gpu-attestation-report-verified":true, + "x-nv-gpu-driver-rim-schema-fetched":true, + "x-nv-gpu-driver-rim-schema-validated":true, + "x-nv-gpu-driver-rim-cert-extracted":true, + "x-nv-gpu-driver-rim-signature-verified":true, + "x-nv-gpu-driver-rim-driver-measurements-available":true, + "x-nv-gpu-driver-vbios-rim-fetched":true, + "x-nv-gpu-vbios-rim-schema-validated":true, + "x-nv-gpu-vbios-rim-cert-extracted":true, + "x-nv-gpu-vbios-rim-signature-verified":true, + "x-nv-gpu-vbios-rim-driver-measurements-available":true, + "x-nv-gpu-vbios-index-conflict":true, + "x-nv-gpu-measurements-match":true + } + } + + """ + super().__init__() + self.verifiers = verifiers + + def get_namespace(self) -> str: + return GPU_NAMESPACE + + def generate(self) -> str: + raise NotImplementedError + + def verify(self, token: str) -> bool: + raise NotImplementedError diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py new file mode 100644 index 0000000000..21bff9035e --- /dev/null +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import subprocess + +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer + +TDX_NAMESPACE = "tdx_" +TDX_CLI_CONFIG = "config.json" +TOKEN_FILE = "token.txt" +VERIFY_FILE = "verify.txt" +ERROR_FILE = "error.txt" + + +class TDXAuthorizer(CCAuthorizer): + def __init__(self, tdx_cli_command: str, config_dir: str) -> None: + super().__init__() + self.tdx_cli_command = tdx_cli_command + self.config_dir = config_dir + + self.config_file = os.path.join(self.config_dir, TDX_CLI_CONFIG) + + def generate(self) -> str: + token_file = os.path.join(self.config_dir, TOKEN_FILE) + out = open(token_file, "w") + error_file = os.path.join(self.config_dir, ERROR_FILE) + err_out = open(error_file, "w") + + command = ["sudo", self.tdx_cli_command, "-c", self.config_file, "token", "--no-eventlog"] + subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) + + if not os.path.exists(error_file) or not os.path.exists(token_file): + return "" + + try: + with open(error_file, "r") as e_f: + if "Error:" in e_f.read(): + return "" + else: + with open(token_file, "r") as t_f: + token = t_f.readline() + return token + except: + return "" + + def verify(self, token: str) -> bool: + out = open(os.path.join(self.config_dir, VERIFY_FILE), "w") + error_file = os.path.join(self.config_dir, ERROR_FILE) + err_out = open(error_file, "w") + + command = [self.tdx_cli_command, "verify", "--config", self.config_file, "--token", token] + subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) + + if not os.path.exists(error_file): + return False + + try: + with open(error_file, "r") as f: + if "Error:" in f.read(): + return False + except: + return False + + return True + + def get_namespace(self) -> str: + return TDX_NAMESPACE diff --git a/nvflare/app_opt/lightning/api.py b/nvflare/app_opt/lightning/api.py index 567fe1b0a1..8cf7c6f29f 100644 --- a/nvflare/app_opt/lightning/api.py +++ b/nvflare/app_opt/lightning/api.py @@ -65,16 +65,17 @@ def __init__(self): self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"} """ - fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict) callbacks = trainer.callbacks - if isinstance(callbacks, list): + if isinstance(callbacks, Callback): + callbacks = [callbacks] + elif not isinstance(callbacks, list): + callbacks = [] + + if not any(isinstance(cb, FLCallback) for cb in callbacks): + fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict) callbacks.append(fl_callback) - elif isinstance(callbacks, Callback): - callbacks = [callbacks, fl_callback] - else: - callbacks = [fl_callback] - if restore_state: + if restore_state and not any(isinstance(cb, RestoreState) for cb in callbacks): callbacks.append(RestoreState()) trainer.callbacks = callbacks diff --git a/nvflare/client/api.py b/nvflare/client/api.py index 2cb43fdbcb..0f95f50b07 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -70,6 +70,8 @@ def send(model: FLModel, clear_cache: bool = True) -> None: model (FLModel): Sends a FLModel object. 
clear_cache: clear cache after send """ + if not isinstance(model, FLModel): + raise TypeError("model needs to be an instance of FLModel") global client_api return client_api.send(model, clear_cache) diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index bbd73a8fc4..07f7a72c68 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -28,7 +28,25 @@ from nvflare.fuel.f3.streaming.stream_types import StreamFuture from nvflare.private.defs import CellChannel -CHANNELS_TO_HANDLE = (CellChannel.SERVER_COMMAND, CellChannel.AUX_COMMUNICATION) +CHANNELS_TO_EXCLUDE = ( + CellChannel.CLIENT_MAIN, + CellChannel.SERVER_MAIN, + CellChannel.SERVER_PARENT_LISTENER, + CellChannel.CLIENT_COMMAND, + CellChannel.CLIENT_SUB_WORKER_COMMAND, + CellChannel.MULTI_PROCESS_EXECUTOR, + CellChannel.SIMULATOR_RUNNER, + CellChannel.RETURN_ONLY, +) + + +def _is_stream_channel(channel: str) -> bool: + if channel is None or channel == "": + return False + elif channel in CHANNELS_TO_EXCLUDE: + return False + # if not excluded, all channels supporting streaming capabilities + return True class SimpleWaiter: @@ -104,13 +122,13 @@ def __getattr__(self, func): This method is called when Python cannot find an invoked method "x" of this class. Method "x" is one of the message sending methods (send_request, broadcast_request, etc.) In this method, we decide which method should be used instead, based on the "channel" of the message. - - If the channel is in CHANNELS_TO_HANDLE, use the method "_x" of this class. - - Otherwise, user the method "x" of the core_cell. + - If the channel is stream channel, use the method "_x" of this class. + - Otherwise, user the method "x" of the CoreCell. """ def method(*args, **kwargs): self.logger.debug(f"__getattr__: {args=}, {kwargs=}") - if kwargs.get("channel") in CHANNELS_TO_HANDLE: + if _is_stream_channel(kwargs.get("channel")): self.logger.debug(f"calling cell {func}") return getattr(self, f"_{func}")(*args, **kwargs) if not hasattr(self.core_cell, func): @@ -311,7 +329,7 @@ def _register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified request_cb {type(cb)} is not callable") - if channel in CHANNELS_TO_HANDLE: + if _is_stream_channel(channel): self.logger.info(f"Register blob CB for {channel=}, {topic=}") adapter = Adapter(cb, self.core_cell.my_info, self) self.register_blob_cb(channel, topic, adapter.call, *args, **kwargs) diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index f933c14349..2618974a81 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -27,7 +27,7 @@ from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.private.defs import AppFolderConstants from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger, create_privacy_manager -from nvflare.private.fed.app.utils import version_check +from nvflare.private.fed.app.utils import component_security_check, version_check from nvflare.private.fed.client.admin import FedAdminAgent from nvflare.private.fed.client.client_engine import ClientEngine from nvflare.private.fed.client.client_status import ClientStatus @@ -108,8 +108,11 @@ def main(args): time.sleep(1.0) with client_engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + 
component_security_check(fl_ctx) + client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) federated_client.register(fl_ctx) fl_ctx.set_prop(FLContextKey.CLIENT_TOKEN, federated_client.token) diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index e800820497..ecb2a1b04f 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -16,8 +16,9 @@ import threading from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import SystemComponents +from nvflare.apis.fl_constant import FLContextKey, SystemComponents from nvflare.apis.workspace import Workspace +from nvflare.private.fed.app.utils import component_security_check from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.job_runner import JobRunner from nvflare.private.fed.server.run_manager import RunManager @@ -119,8 +120,11 @@ def deploy(self, args): run_manager.add_component(SystemComponents.JOB_RUNNER, job_runner) with services.engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + component_security_check(fl_ctx) + threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() services.status = ServerStatus.STARTED diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index b6a84c4516..f85cac0d09 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -505,6 +505,7 @@ def __init__(self, args, clients: [], client_config, deploy_args, build_ctx): self.deploy_args = deploy_args self.build_ctx = build_ctx self.kv_list = parse_vars(args.set) + self.logging_config = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG) self.end_run_clients = [] @@ -573,10 +574,13 @@ def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60): def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name=RunnerTask.TASK_EXEC): open_port = get_open_ports(1)[0] + client_workspace = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME, "app_" + client.client_name) command = ( sys.executable + " -m nvflare.private.fed.app.simulator.simulator_worker -o " - + self.args.workspace + + client_workspace + + " --logging_config " + + self.logging_config + " --client " + client.client_name + " --token " diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index de4ba3b702..7fd94f441b 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -231,15 +231,11 @@ def main(args): thread = threading.Thread(target=check_parent_alive, args=(parent_pid, stop_event)) thread.start() - log_config_file_path = os.path.join(args.workspace, "startup", WorkspaceConstants.LOGGING_CONFIG) - if not os.path.isfile(log_config_file_path): - log_config_file_path = os.path.join(os.path.dirname(__file__), WorkspaceConstants.LOGGING_CONFIG) - logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False) - workspace = os.path.join(args.workspace, SimulatorConstants.JOB_NAME, "app_" + args.client) - log_file = os.path.join(workspace, WorkspaceConstants.LOG_FILE_NAME) + 
logging.config.fileConfig(fname=args.logging_config, disable_existing_loggers=False) + log_file = os.path.join(args.workspace, WorkspaceConstants.LOG_FILE_NAME) add_logfile_handler(log_file) - os.chdir(workspace) + os.chdir(args.workspace) fobs_initialize() AuthorizationService.initialize(EmptyAuthorizer()) # AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG) @@ -263,6 +259,7 @@ def main(args): def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--workspace", "-o", type=str, help="WORKSPACE folder", required=True) + parser.add_argument("--logging_config", type=str, help="logging config file", required=True) parser.add_argument("--client", type=str, help="Client name", required=True) parser.add_argument("--token", type=str, help="Client token", required=True) parser.add_argument("--port", type=str, help="Listen port", required=True) diff --git a/nvflare/private/fed/app/utils.py b/nvflare/private/fed/app/utils.py index 8552ec4945..94712d4b21 100644 --- a/nvflare/private/fed/app/utils.py +++ b/nvflare/private/fed/app/utils.py @@ -20,6 +20,9 @@ import psutil +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.fuel.hci.security import hash_password from nvflare.private.defs import SSLConstants from nvflare.private.fed.runner import Runner @@ -98,3 +101,12 @@ def version_check(): raise RuntimeError("Python versions 3.11 and above are not yet supported. Please use Python 3.8, 3.9 or 3.10.") if sys.version_info < (3, 8): raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10") + + +def component_security_check(fl_ctx: FLContext): + exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + if exceptions: + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + print(f"Unsafe component configured, could not start {fl_ctx.get_identity_name()}!!") + raise RuntimeError(exception) diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index d3475f0253..e95e2f5269 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -17,6 +17,7 @@ import time from typing import List, Optional +from nvflare.apis.event_type import EventType from nvflare.apis.filter import Filter from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_constant import ReturnCode as ShareableRC @@ -292,6 +293,9 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): server's reply to the last message """ + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + shareable = Shareable() + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { @@ -299,7 +303,8 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: task_name, - } + }, + shareable, ) try: result = self.cell.send_request( @@ -328,6 +333,11 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C heartbeats_log_interval = 10 while not self.heartbeat_done: try: + engine.fire_event(EventType.BEFORE_CLIENT_HEARTBEAT, fl_ctx) + shareable = Shareable() + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + 
job_ids = engine.get_all_job_ids() heartbeat_message = new_cell_message( { @@ -336,7 +346,8 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.PROJECT_NAME: task_name, CellMessageHeaderKeys.JOB_IDS: job_ids, - } + }, + shareable, ) try: @@ -367,6 +378,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C except Exception as ex: raise FLCommunicationError("error:client_quit", ex) + engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx) for i in range(wait_times): time.sleep(2) if self.heartbeat_done: diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index 58828121bb..7ea2e8d55b 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -16,7 +16,7 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReturnCode, SystemComponents +from nvflare.apis.fl_constant import FLContextKey, ReturnCode, ServerCommandKey, SystemComponents from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec from nvflare.apis.shareable import Shareable from nvflare.private.admin_defs import Message @@ -68,6 +68,8 @@ def process(self, req: Message, app_ctx) -> Message: fl_ctx.set_prop(key=FLContextKey.CLIENT_RESOURCE_SPECS, value=resource_spec, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False) + shared_fl_ctx = req.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) engine.fire_event(EventType.BEFORE_CHECK_RESOURCE_MANAGER, fl_ctx) block_reason = fl_ctx.get_prop(FLContextKey.JOB_BLOCK_REASON) @@ -78,7 +80,7 @@ def process(self, req: Message, app_ctx) -> Message: is_resource_enough, token = resource_manager.check_resources( resource_requirement=resource_spec, fl_ctx=fl_ctx ) - except Exception: + except Exception as e: result.set_return_code(ReturnCode.EXECUTION_EXCEPTION) result.set_header(ShareableHeader.IS_RESOURCE_ENOUGH, is_resource_enough) diff --git a/nvflare/private/fed/server/admin.py b/nvflare/private/fed/server/admin.py index eec732013a..71fc765939 100644 --- a/nvflare/private/fed/server/admin.py +++ b/nvflare/private/fed/server/admin.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
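A pattern worth calling out before the ``admin.py`` changes below: on the sending side, the client now serializes its ``FLContext`` with ``gen_new_peer_ctx`` and carries it in a ``Shareable`` header; on the receiving side, the server restores it with ``set_peer_context`` before firing the new heartbeat/quit events, so handlers such as the CC manager can read the sender's properties from the peer context. Condensed into two hypothetical helpers (the function names are illustrative; the imports, header key, and calls are the ones this diff uses):

.. code-block:: python

    # Sketch of the two halves of the peer-context handoff. Helper names are
    # illustrative; the header key and calls come from this change.
    from nvflare.apis.fl_constant import ServerCommandKey
    from nvflare.apis.fl_context import FLContext
    from nvflare.apis.shareable import Shareable
    from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx


    def attach_peer_context(fl_ctx: FLContext) -> Shareable:
        # Sender side: snapshot the local context and carry it in the payload.
        shareable = Shareable()
        shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, gen_new_peer_ctx(fl_ctx))
        return shareable


    def restore_peer_context(payload: Shareable, fl_ctx: FLContext) -> None:
        # Receiver side: make the sender's context visible to event handlers.
        shared_fl_ctx = payload.get_header(ServerCommandKey.PEER_FL_CONTEXT)
        fl_ctx.set_peer_context(shared_fl_ctx)

Passing the context explicitly, rather than re-creating it inside each send, is also why ``send_requests`` in ``admin.py`` below now takes an ``fl_ctx`` argument instead of opening a fresh context per request.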
-import copy import threading import time from typing import List, Optional from nvflare.apis.event_type import EventType -from nvflare.apis.shareable import ReservedHeaderKey +from nvflare.apis.fl_constant import ServerCommandKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.net_manager import NetManager @@ -229,11 +230,12 @@ def send_request_to_client(self, req: Message, client_token: str, timeout_secs=2 if not isinstance(req, Message): raise TypeError("request must be Message but got {}".format(type(req))) reqs = {client_token: req} - replies = self.send_requests(reqs, timeout_secs=timeout_secs) - if replies is None or len(replies) <= 0: - return None - else: - return replies[0] + with self.sai.new_context() as fl_ctx: + replies = self.send_requests(reqs, fl_ctx, timeout_secs=timeout_secs) + if replies is None or len(replies) <= 0: + return None + else: + return replies[0] def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> dict: """Send requests to clients @@ -250,12 +252,13 @@ def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> for token, _ in requests.items(): result[token] = None - replies = self.send_requests(requests, timeout_secs=timeout_secs) - for r in replies: - result[r.client_token] = r.reply + with self.sai.new_context() as fl_ctx: + replies = self.send_requests(requests, fl_ctx, timeout_secs=timeout_secs) + for r in replies: + result[r.client_token] = r.reply return result - def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [ClientReply]: + def send_requests(self, requests: dict, fl_ctx: FLContext, timeout_secs=2.0, optional=False) -> [ClientReply]: """Send requests to clients. 
NOTE:: @@ -266,6 +269,7 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl Args: requests: A dict of requests: {client token: request or list of requests} + fl_ctx: FLContext timeout_secs: how long to wait for reply before timeout optional: whether the requests are optional @@ -274,9 +278,10 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl """ for _, request in requests.items(): - with self.sai.new_context() as fl_ctx: - self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) - request.set_header(ReservedHeaderKey.PEER_PROPS, copy.deepcopy(fl_ctx.get_all_public_props())) + # with self.sai.new_context() as fl_ctx: + self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + request.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) return send_requests( cell=self.cell, diff --git a/nvflare/private/fed/server/cmd_utils.py b/nvflare/private/fed/server/cmd_utils.py index d2d7bd6bfd..4cc8cc72ce 100644 --- a/nvflare/private/fed/server/cmd_utils.py +++ b/nvflare/private/fed/server/cmd_utils.py @@ -148,7 +148,8 @@ def send_request_to_clients(self, conn, message): cmd_timeout = conn.get_prop(ConnProps.CMD_TIMEOUT) if not cmd_timeout: cmd_timeout = admin_server.timeout - replies = admin_server.send_requests(requests, timeout_secs=cmd_timeout) + with admin_server.sai.new_context() as fl_ctx: + replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=cmd_timeout) return replies diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index d07dc99246..ae50e4c3a6 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -493,7 +493,7 @@ def register_client(self, request: Message) -> Message: shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) - self.engine.fire_event(EventType.CLIENT_REGISTERED, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.CLIENT_REGISTER_RECEIVED, fl_ctx=fl_ctx) exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) if exceptions: @@ -513,6 +513,7 @@ def register_client(self, request: Message) -> Message: } else: headers = {} + self.engine.fire_event(EventType.CLIENT_REGISTER_PROCESSED, fl_ctx=fl_ctx) return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) except NotAuthenticated as e: self.logger.error(f"Failed to authenticate the register_client: {secure_format_exception(e)}") @@ -539,6 +540,11 @@ def quit_client(self, request: Message) -> Message: token = client.get_token() self.logout_client(token) + data = request.payload + shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) + self.engine.fire_event(EventType.CLIENT_QUIT, fl_ctx=fl_ctx) + headers = {CellMessageHeaderKeys.MESSAGE: "Removed client"} return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) @@ -572,6 +578,11 @@ def client_heartbeat(self, request: Message) -> Message: if error is not None: return make_cellnet_reply(rc=F3ReturnCode.COMM_ERROR, error=error) + data = request.payload + shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) + self.engine.fire_event(EventType.CLIENT_HEARTBEAT_RECEIVED, fl_ctx=fl_ctx) + token = request.get_header(CellMessageHeaderKeys.TOKEN) client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME) @@ -593,6 +604,7 @@ def client_heartbeat(self, request: 
Message) -> Message: f"These jobs: {display_runs} are not running on the server. " f"Ask client: {client_name} to abort these runs." ) + self.engine.fire_event(EventType.CLIENT_HEARTBEAT_PROCESSED, fl_ctx=fl_ctx) return reply def _sync_client_jobs(self, request, client_token): diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index ae2cb7d568..15a9bea833 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -49,7 +49,8 @@ def _send_to_clients(admin_server, client_sites: List[str], engine, message, tim if timeout is None: timeout = admin_server.timeout - replies = admin_server.send_requests(requests, timeout_secs=timeout, optional=optional) + with admin_server.sai.new_context() as fl_ctx: + replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout, optional=optional) return replies @@ -248,7 +249,7 @@ def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo if err: raise RuntimeError(f"Could not start the server App for job: {job_id}.") - replies = engine.start_client_job(job_id, client_sites) + replies = engine.start_client_job(job_id, client_sites, fl_ctx) client_sites_names = list(client_sites.keys()) check_client_replies(replies=replies, client_sites=client_sites_names, command=f"start job ({job_id})") display_sites = ",".join(client_sites_names) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index d6b636b575..955861ec66 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -721,8 +721,8 @@ def reset_errors(self, job_id) -> str: return f"reset the server error stats for job: {job_id}" - def _send_admin_requests(self, requests, timeout_secs=10) -> List[ClientReply]: - return self.server.admin_server.send_requests(requests, timeout_secs=timeout_secs) + def _send_admin_requests(self, requests, fl_ctx: FLContext, timeout_secs=10) -> List[ClientReply]: + return self.server.admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout_secs) def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> Dict[str, Tuple[bool, str]]: requests = {} @@ -737,7 +737,7 @@ def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> requests.update({client.token: request}) replies = [] if requests: - replies = self._send_admin_requests(requests, 15) + replies = self._send_admin_requests(requests, fl_ctx, 15) result = {} for r in replies: site_name = r.client_name @@ -760,12 +760,14 @@ def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> def _make_message_for_check_resource(self, job, resource_requirements, fl_ctx): request = Message(topic=TrainingTopic.CHECK_RESOURCE, body=resource_requirements) request.set_header(RequestHeader.JOB_ID, job.job_id) + request.set_header(RequestHeader.REQUIRE_AUTHZ, "true") + request.set_header(RequestHeader.ADMIN_COMMAND, AdminCommandNames.CHECK_RESOURCES) set_message_security_data(request, job, fl_ctx) return request def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): requests = {} for site_name, result in resource_check_results.items(): @@ -778,9 +780,9 @@ def cancel_client_resources( if client: requests.update({client.token: request}) if requests: - _ = 
self._send_admin_requests(requests) + _ = self._send_admin_requests(requests, fl_ctx) - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): requests = {} for site, dispatch_info in client_sites.items(): resource_requirement = dispatch_info.resource_requirements @@ -793,7 +795,7 @@ def start_client_job(self, job_id, client_sites): requests.update({client.token: request}) replies = [] if requests: - replies = self._send_admin_requests(requests, timeout_secs=20) + replies = self._send_admin_requests(requests, fl_ctx, timeout_secs=20) return replies def stop_all_jobs(self): diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py index c735e82d44..f58cb8a2f2 100644 --- a/nvflare/private/fed/server/server_json_config.py +++ b/nvflare/private/fed/server/server_json_config.py @@ -42,7 +42,6 @@ def __init__(self, id, controller: Controller): """ self.id = id self.controller = controller - self.controller.set_communicator(WFCommServer()) class ServerJsonConfigurator(FedJsonConfigurator): @@ -128,13 +127,13 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node): return if re.search(r"^workflows\.#[0-9]+$", path): - workflow = self.authorize_and_build_component(element, config_ctx, node) - if not isinstance(workflow, Controller): - raise ConfigError('"workflow" must be a Controller object, but got {}'.format(type(workflow))) + controller = self.authorize_and_build_component(element, config_ctx, node) + if not isinstance(controller, Controller): + raise ConfigError('"controller" must be a Controller object, but got {}'.format(type(controller))) cid = element.get("id", None) if not cid: - cid = type(workflow).__name__ + cid = type(controller).__name__ if not isinstance(cid, str): raise ConfigError('"id" must be str but got {}'.format(type(cid))) @@ -145,8 +144,12 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node): if cid in self.components: raise ConfigError('duplicate component id "{}"'.format(cid)) - self.workflows.append(WorkFlow(cid, workflow)) - self.components[cid] = workflow + communicator = WFCommServer() + self.handlers.append(communicator) + controller.set_communicator(communicator) + + self.workflows.append(WorkFlow(cid, controller)) + self.components[cid] = controller return def _get_all_workflows_ids(self): diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 6b8886be47..b675220978 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -126,9 +126,8 @@ def _execute_run(self): fl_ctx.set_prop(FLContextKey.WORKFLOW, wf.id, sticky=True) - wf.controller.initialize(fl_ctx) wf.controller.communicator.initialize_run(fl_ctx) - wf.controller.start_controller(fl_ctx) + wf.controller.initialize(fl_ctx) self.log_info(fl_ctx, "Workflow {} ({}) started".format(wf.id, type(wf.controller))) self.log_debug(fl_ctx, "firing event EventType.START_WORKFLOW") @@ -381,7 +380,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if self.current_wf is None: return - self.current_wf.controller.communicator.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx) + fl_ctx.set_prop(FLContextKey.DEAD_JOB_CLIENT_NAME, client_name) + self.log_debug(fl_ctx, "firing event EventType.JOB_DEAD") + self.fire_event(EventType.JOB_DEAD, fl_ctx) + except Exception as e: self.log_exception( fl_ctx, f"Error processing dead job by workflow 
{self.current_wf.id}: {secure_format_exception(e)}" diff --git a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py index df6c63a912..1f7f9ab48e 100644 --- a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py +++ b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py @@ -120,7 +120,7 @@ def persist_components(self, fl_ctx: FLContext, completed: bool): def restore_components(self, snapshot, fl_ctx: FLContext): pass - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): pass def check_client_resources( @@ -136,15 +136,15 @@ def get_client_name_from_token(self, token): return self.clients.get(token) def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): - with self.new_context() as fl_ctx: - for site_name, result in resource_check_results.items(): - check_result, token = result - if check_result and token: - self.clients[site_name].resource_manager.cancel_resources( - resource_requirement=resource_reqs[site_name], token=token, fl_ctx=fl_ctx - ) + # with self.new_context() as fl_ctx: + for site_name, result in resource_check_results.items(): + check_result, token = result + if check_result and token: + self.clients[site_name].resource_manager.cancel_resources( + resource_requirement=resource_reqs[site_name], token=token, fl_ctx=fl_ctx + ) def update_job_run_status(self): pass diff --git a/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py b/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py new file mode 100644 index 0000000000..3db086a277 --- /dev/null +++ b/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
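The ``server_runner.py`` hunk above replaces the direct call into the workflow communicator with an event: the runner stores the dead client's name under ``FLContextKey.DEAD_JOB_CLIENT_NAME`` and fires ``EventType.JOB_DEAD``, decoupling the runner from any particular communicator. Any component can now react to the notice; a hypothetical listener (``DeadJobMonitor`` is illustrative, while the event type and context key are the ones introduced above):

.. code-block:: python

    # Hypothetical listener -- DeadJobMonitor is illustrative; the event type
    # and context key come from this change.
    from nvflare.apis.event_type import EventType
    from nvflare.apis.fl_component import FLComponent
    from nvflare.apis.fl_constant import FLContextKey
    from nvflare.apis.fl_context import FLContext


    class DeadJobMonitor(FLComponent):
        def handle_event(self, event_type: str, fl_ctx: FLContext):
            if event_type == EventType.JOB_DEAD:
                client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
                self.log_info(fl_ctx, f"dead-job notice received for client {client_name}")

The new ``cc_manager_test.py`` below exercises the server-side token bookkeeping that consumes the peer context populated by these changes.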
+from unittest.mock import Mock + +from nvflare.apis.fl_constant import ReservedKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.server_engine_spec import ServerEngineSpec +from nvflare.app_opt.confidential_computing.cc_manager import ( + CC_INFO, + CC_NAMESPACE, + CC_TOKEN, + CC_TOKEN_VALIDATED, + CC_VERIFICATION_FAILED, + CCManager, +) +from nvflare.app_opt.confidential_computing.tdx_authorizer import TDX_NAMESPACE, TDXAuthorizer + +VALID_TOKEN = "valid_token" +INVALID_TOKEN = "invalid_token" + + +class TestCCManager: + def setup_method(self, method): + issues_conf = [{"issuer_id": "tdx_authorizer", "token_expiration": 250}] + + verify_ids = (["tdx_authorizer"],) + self.cc_manager = CCManager(issues_conf, verify_ids) + + def test_authorizer_setup(self): + + fl_ctx, tdx_authorizer = self._setup_authorizers() + + assert self.cc_manager.cc_issuers == {tdx_authorizer: 250} + assert self.cc_manager.cc_verifiers == {TDX_NAMESPACE: tdx_authorizer} + + def _setup_authorizers(self): + fl_ctx = Mock(spec=FLContext) + fl_ctx.get_identity_name.return_value = "server" + engine = Mock(spec=ServerEngineSpec) + fl_ctx.get_engine.return_value = engine + + tdx_authorizer = Mock(spec=TDXAuthorizer) + tdx_authorizer.get_namespace.return_value = TDX_NAMESPACE + tdx_authorizer.verify = self._verify_token + engine.get_component.return_value = tdx_authorizer + self.cc_manager._setup_cc_authorizers(fl_ctx) + + tdx_authorizer.generate.return_value = VALID_TOKEN + self.cc_manager._generate_tokens(fl_ctx) + + return fl_ctx, tdx_authorizer + + def _verify_token(self, token): + if token == VALID_TOKEN: + return True + else: + return False + + def test_add_client_token(self): + + cc_info1, cc_info2 = self._add_failed_tokens() + + assert self.cc_manager.participant_cc_info["client1"] == cc_info1 + assert self.cc_manager.participant_cc_info["client2"] == cc_info2 + + def _add_failed_tokens(self): + self.cc_manager._verify_running_jobs = Mock() + client_name = "client1" + valid_token = VALID_TOKEN + cc_info1, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + + client_name = "client2" + valid_token = INVALID_TOKEN + cc_info2, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + return cc_info1, cc_info2 + + def test_verification_success(self): + + self._setup_authorizers() + + self.cc_manager._verify_running_jobs = Mock() + + self.cc_manager._verify_running_jobs = Mock() + client_name = "client1" + valid_token = VALID_TOKEN + cc_info1, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + + client_name = "client2" + valid_token = VALID_TOKEN + cc_info2, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + + self.cc_manager._handle_expired_tokens = Mock() + + err, participant_tokens = self.cc_manager._verify_participants(["client1", "client2"]) + + assert not err + + def test_verification_failed(self): + + self._setup_authorizers() + + self.cc_manager._verify_running_jobs = Mock() + self._add_failed_tokens() + self.cc_manager._handle_expired_tokens = Mock() + + err, participant_tokens = self.cc_manager._verify_participants(["client1", "client2"]) + + assert "client2" in err + assert CC_VERIFICATION_FAILED in err + + def _add_client_token(self, client_name, valid_token): + peer_ctx = FLContext() + cc_info = [{CC_TOKEN: valid_token, CC_NAMESPACE: TDX_NAMESPACE, CC_TOKEN_VALIDATED: False}] + 
peer_ctx.set_prop(CC_INFO, cc_info) + peer_ctx.set_prop(ReservedKey.IDENTITY_NAME, client_name) + fl_ctx = Mock(spec=FLContext) + fl_ctx.get_peer_context.return_value = peer_ctx + return cc_info, fl_ctx diff --git a/tests/unit_test/private/fed/server/fed_server_test.py b/tests/unit_test/private/fed/server/fed_server_test.py index abf8a21974..235cfac0c9 100644 --- a/tests/unit_test/private/fed/server/fed_server_test.py +++ b/tests/unit_test/private/fed/server/fed_server_test.py @@ -16,6 +16,7 @@ import pytest +from nvflare.apis.shareable import Shareable from nvflare.private.defs import CellMessageHeaderKeys, new_cell_message from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.server_state import ColdState, HotState @@ -46,7 +47,8 @@ def test_heart_beat_abort_jobs(self, server_state, expected): CellMessageHeaderKeys.CLIENT_NAME: "client_name", CellMessageHeaderKeys.PROJECT_NAME: "task_name", CellMessageHeaderKeys.JOB_IDS: ["extra_job"], - } + }, + Shareable(), ) result = server.client_heartbeat(request)
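Taken together with the ``fed_server.py`` changes, this test update reflects the new message shape: register, quit, and heartbeat cell messages now carry a ``Shareable`` payload (used for the peer context) alongside the header dict, which is why the test passes an empty ``Shareable`` explicitly. For reference, constructing such a message (header keys as used in the test; the values are placeholders):

.. code-block:: python

    # Heartbeat-style cell message under the new convention: a header dict
    # plus a Shareable payload. An empty Shareable suffices when no peer
    # context needs to travel with the message, as in the unit test.
    from nvflare.apis.shareable import Shareable
    from nvflare.private.defs import CellMessageHeaderKeys, new_cell_message

    message = new_cell_message(
        {
            CellMessageHeaderKeys.TOKEN: "token",
            CellMessageHeaderKeys.CLIENT_NAME: "client_name",
            CellMessageHeaderKeys.PROJECT_NAME: "task_name",
        },
        Shareable(),
    )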