From 75faaec96ec583f641b85cec23c904bb6ea0f4e7 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 16 Apr 2024 13:48:37 -0400 Subject: [PATCH 01/14] Added more logging for the job status changing. (#2480) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added more logging for the job status changing. * Fixed a logging call error. --------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- nvflare/private/fed/server/job_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index 15a9bea833..d6bf91e888 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -444,6 +444,7 @@ def run(self, fl_ctx: FLContext): }, fl_ctx, ) + self.log_info(fl_ctx, f"Updated the schedule history of Job: {job_id}") if failed_clients: deployable_clients = {k: v for k, v in client_sites.items() if k not in failed_clients} @@ -465,6 +466,7 @@ def run(self, fl_ctx: FLContext): with self.lock: self.running_jobs[job_id] = ready_job job_manager.set_status(ready_job.job_id, RunStatus.RUNNING, fl_ctx) + self.log_info(fl_ctx, f"Job: {job_id} started to run, status changed to RUNNING.") except Exception as e: if job_id: if job_id in self.running_jobs: From 2b4cf2acdccb3bd27d64924c5f39c8dc50625e59 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Tue, 16 Apr 2024 13:38:37 -0700 Subject: [PATCH 02/14] Fix update client status (#2508) * check workflow id before updating client status * change order of checks --- nvflare/app_common/ccwf/server_ctl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nvflare/app_common/ccwf/server_ctl.py b/nvflare/app_common/ccwf/server_ctl.py index 80933969b5..82f6499183 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -440,9 +440,6 @@ def _update_client_status(self, fl_ctx: FLContext): peer_ctx = fl_ctx.get_peer_context() assert isinstance(peer_ctx, FLContext) client_name = peer_ctx.get_identity_name() - if client_name not in self.client_statuses: - self.log_error(fl_ctx, f"received result from unknown client {client_name}!") - return # see whether status is available reports = peer_ctx.get_prop(Constant.STATUS_REPORTS) @@ -454,6 +451,10 @@ def _update_client_status(self, fl_ctx: FLContext): if not my_report: return + if client_name not in self.client_statuses: + self.log_error(fl_ctx, f"received result from unknown client {client_name}!") + return + report = status_report_from_dict(my_report) cs = self.client_statuses[client_name] assert isinstance(cs, ClientStatus) From f2fd48acdf4a050ce518e1dc8f26b55096e0892e Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Wed, 17 Apr 2024 11:26:29 -0700 Subject: [PATCH 03/14] Add user guide on how to deploy to EKS (#2510) * Add user guide on how to deploy to EKS * Address comments --- docs/real_world_fl.rst | 1 + docs/real_world_fl/kubernetes.rst | 371 ++++++++++++++++++++++++++++++ docs/real_world_fl/overview.rst | 6 + 3 files changed, 378 insertions(+) create mode 100644 docs/real_world_fl/kubernetes.rst diff --git a/docs/real_world_fl.rst b/docs/real_world_fl.rst index d9867adb19..c2b0b78e62 100644 --- a/docs/real_world_fl.rst +++ b/docs/real_world_fl.rst @@ -30,5 +30,6 @@ to see the capabilities of the system and how it can be operated. 
   real_world_fl/job
   real_world_fl/workspace
   real_world_fl/cloud_deployment
+   real_world_fl/kubernetes
   real_world_fl/notes_on_large_models
   user_guide/security/identity_security
diff --git a/docs/real_world_fl/kubernetes.rst b/docs/real_world_fl/kubernetes.rst
new file mode 100644
index 0000000000..8c52bddc8f
--- /dev/null
+++ b/docs/real_world_fl/kubernetes.rst
@@ -0,0 +1,371 @@
+.. _eks_deployment:
+
############################################
Amazon Elastic Kubernetes Service Deployment
############################################
In this document, we describe how to run the entire NVIDIA FLARE system inside one Amazon Elastic Kubernetes Service (EKS) cluster. For information on
how to run NVIDIA FLARE inside microk8s (a local Kubernetes cluster), please refer to :ref:`helm_chart`. That document describes how to
provision one NVIDIA FLARE system, configure your microk8s cluster, deploy the servers, the overseer and the clients to that cluster, and
control and submit jobs to that NVIDIA FLARE system from the admin console.


Start the EKS
=============
We assume that you have an AWS account which allows you to start an EKS cluster. We also assume you have eksctl, aws and kubectl installed on your local machine.
Note that the versions of those CLIs may affect the operations. We suggest keeping them updated.

The first step is to start the EKS cluster with eksctl. The following is a sample yaml file, ``cluster.yaml``, to create the EKS cluster with one command.

.. code-block:: yaml

    apiVersion: eksctl.io/v1alpha5
    kind: ClusterConfig

    metadata:
      name: nvflare-cluster
      region: us-west-2
      tags:
        project: nvflare

    nodeGroups:
      - name: worker-node
        instanceType: t3.large
        desiredCapacity: 2

.. code-block:: shell

    eksctl create cluster -f cluster.yaml

After this, you will have one cluster with two `t3.large` EC2 nodes.


Provision
=========

With NVIDIA FLARE installed on your local machine, you can create one set of startup kits easily with ``nvflare provision``. If there is a project.yml file
in your current working directory, ``nvflare provision`` will create a workspace directory. If that project.yml file does not exist, ``nvflare provision`` will
create a sample project.yml for you. For simplicity, we suggest you remove/rename any existing project.yml and workspace directory. Then provision the
set of startup kits from scratch. When selecting the sample project.yml at provisioning time, select the non-HA one, as most Kubernetes clusters already provide high availability themselves.

After provisioning, you will have a workspace/example_project/prod_00 folder, which includes server, site-1, site-2 and admin@nvidia.com folders. If you
would like to use other names instead of ``site-1``, ``site-2``, etc., you can remove the workspace folder and modify the project.yml file. After that,
you can run the ``nvflare provision`` command again to get the new set of startup kits.

Persistent Volume
=================

EKS provides several ways to create persistent volumes. Before you can create the volume,
you will need to create one OIDC provider, add one service account, and attach a policy to two roles: the node instance group role and that service account.

.. code-block:: shell

    eksctl utils associate-iam-oidc-provider --region=us-west-2 --cluster=nvflare-cluster --approve

.. code-block:: shell

    eksctl create iamserviceaccount \
        --region us-west-2 \
        --name ebs-csi-controller-sa \
        --namespace kube-system \
        --cluster nvflare-cluster \
        --attach-policy-arn arn:aws:iam::aws:policy/service-role/AmazonEBSCSIDriverPolicy \
        --approve \
        --role-only \
        --role-name AmazonEKS_EBS_CSI_DriverRole


.. code-block:: shell

    eksctl create addon --name aws-ebs-csi-driver \
        --cluster nvflare-cluster \
        --service-account-role-arn arn:aws:iam::$(aws sts get-caller-identity --query Account --output text):role/AmazonEKS_EBS_CSI_DriverRole \
        --force

The following is the policy json file that you have to attach to the roles.

.. code-block:: json

    {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "Policy4EKS",
                "Effect": "Allow",
                "Action": [
                    "ec2:DetachVolume",
                    "ec2:AttachVolume",
                    "ec2:DeleteVolume",
                    "ec2:DescribeInstances",
                    "ec2:DescribeTags",
                    "ec2:DeleteTags",
                    "ec2:CreateTags",
                    "ec2:DescribeVolumes",
                    "ec2:CreateVolume"
                ],
                "Resource": [
                    "*"
                ]
            }
        ]
    }

The following yaml file uses the EKS gp2 StorageClass to allocate 5 GiB of space. You
can run ``kubectl apply -f volume.yaml`` to make the volume available.

.. code-block:: yaml

    apiVersion: v1
    kind: PersistentVolumeClaim
    metadata:
      name: nvflare-pv-claim
      labels:
        app: nvflare
    spec:
      accessModes:
        - ReadWriteOnce
      resources:
        requests:
          storage: 5Gi
      storageClassName: gp2

After that, your EKS persistent volume claim should be created and waiting for its first consumer.


Start Helper Pod
================

Now you will need to copy your startup kits to your EKS cluster. Those startup kits will be copied into the volume you just created.
In order to access the volume, we deploy a helper pod that mounts that persistent volume, and then use ``kubectl cp`` to copy files from your
local machine to the cluster.

The following is the helper pod yaml file.

.. code-block:: yaml

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      labels:
        run: bb8
      name: bb8
    spec:
      replicas: 1
      selector:
        matchLabels:
          run: bb8
      template:
        metadata:
          labels:
            run: bb8
        spec:
          containers:
            - args:
                - sleep
                - "50000"
              image: busybox
              name: bb8
              volumeMounts:
                - name: nvfl
                  mountPath: /workspace/nvfl/
          volumes:
            - name: nvfl
              persistentVolumeClaim:
                claimName: nvflare-pv-claim


All pods can be deployed with ``kubectl apply -f``, so we just need the following command.

.. code-block:: shell

    kubectl apply -f bb8.yaml

Your helper pod should be up and running very soon. Now copy the startup kits to the cluster (``<bb8-pod-name>`` below is the name of the running helper pod, as shown by ``kubectl get pods``) with

.. code-block:: shell

    kubectl cp workspace/example_project/prod_00/server <bb8-pod-name>:/workspace/nvfl/

And do the same for site-1, site-2 and admin@nvidia.com.

This will make the entire set of startup kits available in the nvflare-pv-claim of the cluster so that the NVIDIA FLARE system
can mount nvflare-pv-claim and access the startup kits.

After copying those folders to nvflare-pv-claim, you can shut down the helper pod. The nvflare-pv-claim and its contents will stay and remain
available to the server/client/admin pods.

Start Server Pod
================

The NVIDIA FLARE server consists of two parts in a Kubernetes cluster. The server needs compute resources to handle model updates,
aggregation and other operations, and it also needs to provide a service for clients and admins
to connect to. Therefore, the following are two separate yaml files that work together to create the NVIDIA FLARE server in EKS.

..
code-block:: yaml + + apiVersion: apps/v1 + kind: Deployment + metadata: + labels: + run: nvflare + name: nvflare + spec: + replicas: 1 + selector: + matchLabels: + run: nvflare + template: + metadata: + labels: + run: nvflare + spec: + containers: + - args: + - -u + - -m + - nvflare.private.fed.app.server.server_train + - -m + - /workspace/nvfl/server + - -s + - fed_server.json + - --set + - secure_train=true + - config_folder=config + - org=nvidia + command: + - /usr/local/bin/python3 + image: nvflare/nvflare:2.4.0 + imagePullPolicy: Always + name: nvflare + volumeMounts: + - name: nvfl + mountPath: /workspace/nvfl/ + volumes: + - name: nvfl + persistentVolumeClaim: + claimName: nvflare-pv-claim + + +.. code-block:: yaml + + apiVersion: v1 + kind: Service + metadata: + labels: + run: server + name: server + spec: + ports: + - port: 8002 + protocol: TCP + targetPort: 8002 + name: flport + - port: 8003 + protocol: TCP + targetPort: 8003 + name: adminport + selector: + run: nvflare + + +Note that the pod will use nvflare/nvflare:2.4.0 container image from dockerhub.com. This image only includes the necessary dependencies to start +NVIDIA FLARE system. If you require additional dependencies, such as Torch or MONAI, you will need to build and publish your own image and update +the yaml file accordingly. + +Start Client Pods +================= + +For the client pods, we only need one yaml file for eacch client. The following is the deployment yaml file for site-1. + +.. code-block:: yaml + + apiVersion: apps/v1 + kind: Deployment + metadata: + labels: + run: site1 + name: site1 + spec: + replicas: 1 + selector: + matchLabels: + run: site1 + template: + metadata: + labels: + run: site1 + spec: + containers: + - args: + - -u + - -m + - nvflare.private.fed.app.client.client_train + - -m + - /workspace/nvfl/site-1 + - -s + - fed_client.json + - --set + - secure_train=true + - uid=site-1 + - config_folder=config + - org=nvidia + command: + - /usr/local/bin/python3 + image: nvflare/nvflare:2.4.0 + imagePullPolicy: Always + name: site1 + volumeMounts: + - name: nvfl + mountPath: /workspace/nvfl/ + volumes: + - name: nvfl + persistentVolumeClaim: + claimName: nvflare-pv-claim + +Once the client is up and running, you can check the server log with ``kubectl logs`` and the log should show the clients registered. + +Start and Connect to Admin Pods +=============================== + +We can also run the admin console inside the EKS cluster to submit jobs to the NVIDIA FLARE running in the EKS cluster. Start the admin pod +with the following yaml file. + +.. code-block:: yaml + + apiVersion: apps/v1 + kind: Deployment + metadata: + labels: + run: admin + name: admin + spec: + replicas: 1 + selector: + matchLabels: + run: admin + template: + metadata: + labels: + run: admin + spec: + containers: + - args: + - "50000" + command: + - /usr/bin/sleep + image: nvflare/nvflare:2.4.0 + imagePullPolicy: Always + name: admin + volumeMounts: + - name: nvfl + mountPath: /workspace/nvfl/ + volumes: + - name: nvfl + persistentVolumeClaim: + claimName: nvflare-pv-claim + +Once the admin pod is running, you can enter the pod with ``kubectl exec`` , cd to ``/workspace/nvfl/admin@nvidia.com/startup`` and run ``fl_admin.sh``. + + +Note that you need to copy the job from your local machine to the EKS cluster so that the ``transfer`` directory of admin@nvidia.com contains the jobs +you would like to run in that EKS cluster. 
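As a rough sketch of that step, assuming the ``bb8`` helper pod described earlier is still running (or has been redeployed) and that your local job folder
is named ``my_job`` (an illustrative name), you could stage the job on the shared volume with ``kubectl cp``. The exact ``transfer`` location depends on
your admin client configuration, but with the layout used in this guide it would typically be under ``/workspace/nvfl/admin@nvidia.com/transfer``:

.. code-block:: shell

    # find the helper pod name, then copy the local job folder into the admin transfer directory
    kubectl get pods
    kubectl cp my_job <bb8-pod-name>:/workspace/nvfl/admin@nvidia.com/transfer/my_job

Once the job folder is on the volume, it is visible from the admin pod, and you can submit it from ``fl_admin.sh`` with the ``submit_job`` command.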
+ diff --git a/docs/real_world_fl/overview.rst b/docs/real_world_fl/overview.rst index 8f0572c1ec..23d494d0b1 100644 --- a/docs/real_world_fl/overview.rst +++ b/docs/real_world_fl/overview.rst @@ -159,6 +159,12 @@ See how to deploy to Azure and AWS clouds can be found in :ref:`cloud_deployment Deploy to Google Cloud will be made available in a future release. +Kubernetes Deployment +===================== +As mentioned above, you can run NVIDIA FLARE in the public cloud. If you prefer to deploy NVIDIA FLARE in Amazon Elastic Kubernetes Service (EKS), +you can find the deployment guide in :ref:`eks_deployment`. + + Starting Federated Learning Servers ============================================= The FL Server will coordinate the federated learning training and be the main hub all clients and admin From f948b6ecc783526516b5a13825ff047169de32b8 Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Thu, 18 Apr 2024 13:10:42 -0400 Subject: [PATCH 04/14] Improve dead client handling (#2506) * dev * test dead client cmd * added more info for dead client tracing * remove unused imports * fix unit test * fix test case * address PR comments --------- Co-authored-by: Sean Yang --- nvflare/apis/controller_spec.py | 11 + nvflare/apis/event_type.py | 3 +- nvflare/apis/fl_constant.py | 10 +- nvflare/apis/impl/controller.py | 13 + nvflare/apis/impl/wf_comm_server.py | 313 +++++++++++------- nvflare/apis/wf_comm_spec.py | 31 ++ nvflare/app_common/workflows/cyclic_ctl.py | 24 +- nvflare/private/fed/client/client_runner.py | 2 +- nvflare/private/fed/server/fed_server.py | 35 +- nvflare/private/fed/server/server_commands.py | 2 + nvflare/private/fed/server/server_engine.py | 17 +- nvflare/private/fed/server/server_runner.py | 19 +- nvflare/private/fed/server/sys_cmd.py | 18 + 13 files changed, 326 insertions(+), 172 deletions(-) diff --git a/nvflare/apis/controller_spec.py b/nvflare/apis/controller_spec.py index f0019da206..2f18f95623 100644 --- a/nvflare/apis/controller_spec.py +++ b/nvflare/apis/controller_spec.py @@ -542,3 +542,14 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ fl_ctx: the FL context """ pass + + def get_client_disconnect_time(self, client_name): + """Get the time that the client is deemed disconnected. + + Args: + client_name: the name of the client + + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected. 
+ + """ + return None diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index e2448da113..5bdbe50bc8 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -34,7 +34,8 @@ class EventType(object): JOB_COMPLETED = "_job_completed" JOB_ABORTED = "_job_aborted" JOB_CANCELLED = "_job_cancelled" - JOB_DEAD = "_job_dead" + CLIENT_DISCONNECTED = "_client_disconnected" + CLIENT_RECONNECTED = "_client_reconnected" BEFORE_PULL_TASK = "_before_pull_task" AFTER_PULL_TASK = "_after_pull_task" diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 4752eb4a29..71ae9c0c45 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -159,7 +159,8 @@ class FLContextKey(object): CLIENT_TOKEN = "__client_token" AUTHORIZATION_RESULT = "_authorization_result" AUTHORIZATION_REASON = "_authorization_reason" - DEAD_JOB_CLIENT_NAME = "_dead_job_client_name" + DISCONNECTED_CLIENT_NAME = "_disconnected_client_name" + RECONNECTED_CLIENT_NAME = "_reconnected_client_name" CLIENT_REGISTER_DATA = "_client_register_data" SECURITY_ITEMS = "_security_items" @@ -263,6 +264,7 @@ class ServerCommandKey(object): CLIENTS = "clients" COLLECTOR = "collector" TURN_TO_COLD = "__turn_to_cold__" + REASON = "reason" class FedEventHeader(object): @@ -464,6 +466,12 @@ class ConfigVarName: # client and server: query interval for reliable message RM_QUERY_INTERVAL = "rm_query_interval" + # server: wait this long since client death report before treating the client as dead/disconnected + DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period" + + # server: wait this long since job schedule time before starting to check dead/disconnected clients + DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" + class SystemVarName: """ diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 6091ac23f1..924512f77b 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -145,3 +145,16 @@ def cancel_task( def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): self.communicator.cancel_all_tasks(completion_status, fl_ctx) + + def get_client_disconnect_time(self, client_name): + """Get the time when the client is deemed disconnected. + + Args: + client_name: the name of the client + + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected. 
+ + """ + if not self.communicator: + return None + return self.communicator.get_client_disconnect_time(client_name) diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index 06b5d13457..679efe643d 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -20,7 +20,7 @@ from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, SystemConfigs from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import job_from_meta from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy @@ -40,12 +40,6 @@ _TASK_KEY_MANAGER = "___mgr" _TASK_KEY_DONE = "___done" -# wait this long since client death report before treating the client as dead -_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period" - -# wait this long since job schedule time before starting to check dead clients -_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" - def _check_positive_int(name, value): if not isinstance(value, int): @@ -79,6 +73,12 @@ def _get_client_task(target, task: Task): return None +class _DeadClientStatus: + def __init__(self): + self.report_time = time.time() + self.disconnect_time = None + + class WFCommServer(FLComponent, WFCommSpec): def __init__(self, task_check_period=0.2): """Manage life cycles of tasks and their destinations. @@ -93,9 +93,10 @@ def __init__(self, task_check_period=0.2): self._client_task_map = {} # client_task_id => client_task self._all_done = False self._task_lock = Lock() - self._task_monitor = threading.Thread(target=self._monitor_tasks, args=()) + self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), daemon=True) self._task_check_period = task_check_period - self._dead_client_reports = {} # clients that reported the job is dead on it: name => report time + self._dead_client_grace = 60.0 + self._dead_clients = {} # clients reported dead: name => _DeadClientStatus self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads # make sure check_tasks, process_task_request, process_submission does not interfere with each other self._controller_lock = Lock() @@ -112,13 +113,16 @@ def initialize_run(self, fl_ctx: FLContext): """ engine = fl_ctx.get_engine() if not engine: - self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx) + self.system_panic(f"Engine not found. {self.name} exiting.", fl_ctx) return self._engine = engine + self._dead_client_grace = ConfigService.get_float_var( + name=ConfigVarName.DEAD_CLIENT_GRACE_PERIOD, conf=SystemConfigs.APPLICATION_CONF, default=60.0 + ) self._task_monitor.start() - def _try_again(self) -> Tuple[str, str, Shareable]: + def _try_again(self) -> Tuple[str, str, Optional[Shareable]]: # TODO: how to tell client no shareable available now? 
return "", "", None @@ -135,7 +139,7 @@ def _set_stats(self, fl_ctx: FLContext): "collector must be an instance of GroupInfoCollector, but got {}".format(type(collector)) ) collector.add_info( - group_name=self.controller._name, + group_name=self.name, info={ "tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks}, }, @@ -150,12 +154,15 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): """ if event_type == InfoCollector.EVENT_TYPE_GET_STATS: self._set_stats(fl_ctx) - elif event_type == EventType.JOB_DEAD: - client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) - with self._dead_clients_lock: - self.log_info(fl_ctx, f"received dead job report from client {client_name}") - if not self._dead_client_reports.get(client_name): - self._dead_client_reports[client_name] = time.time() + + def process_dead_client_report(self, client_name: str, fl_ctx: FLContext): + with self._dead_clients_lock: + self.log_warning(fl_ctx, f"received dead job report for client {client_name}") + if not self._dead_clients.get(client_name): + self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list") + self._dead_clients[client_name] = _DeadClientStatus() + else: + self.log_warning(fl_ctx, f"discarded dead client report {client_name=}: already on watch list") def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: """Called by runner when a client asks for a task. @@ -183,9 +190,6 @@ def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[s if not isinstance(client, Client): raise TypeError("client must be an instance of Client, but got {}".format(type(client))) - with self._dead_clients_lock: - self._dead_client_reports.pop(client.name, None) - if not isinstance(fl_ctx, FLContext): raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) @@ -371,13 +375,6 @@ def _do_process_submission( if not isinstance(client, Client): raise TypeError("client must be an instance of Client, but got {}".format(type(client))) - # reset the dead job report! - # note that due to potential race conditions, a client may fail to include the job id in its - # heartbeat (since the job hasn't started at the time of heartbeat report), but then includes - # the job ID later. - with self._dead_clients_lock: - self._dead_client_reports.pop(client.name, None) - if not isinstance(fl_ctx, FLContext): raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) if not isinstance(result, Shareable): @@ -490,18 +487,23 @@ def broadcast( ): """Schedule a broadcast task. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + The task is scheduled into a task list. + Clients can request tasks and controller will dispatch the task to eligible clients. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. - wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. 
+ targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + min_responses (int, optional): the condition to mark this task as completed because enough clients + respond with submission. Defaults to 1. + wait_time_after_min_received (int, optional): a grace period for late clients to contribute their + submission. 0 means no grace period. Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. Raises: - ValueError: min_responses is greater than the length of targets since this condition will make the task, if allowed to be scheduled, never exit. + ValueError: min_responses is greater than the length of targets since this condition will make the task, + if allowed to be scheduled, never exit. """ _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) _check_positive_int("min_responses", min_responses) @@ -527,16 +529,21 @@ def broadcast_and_wait( ): """Schedule a broadcast task. This is a blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. - wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + min_responses (int, optional): the condition to mark this task as completed because enough clients + respond with submission. Defaults to 1. + wait_time_after_min_received (int, optional): a grace period for late clients to contribute their + submission. 0 means no grace period. Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs + this method to return. Defaults to None. """ self.broadcast( task=task, @@ -550,13 +557,16 @@ def broadcast_and_wait( def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None): """Schedule a broadcast task. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch + the task to eligible clients. + This broadcast will not end. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. 
+ targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. """ _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) manager = BcastForeverTaskManager() @@ -572,14 +582,18 @@ def send( ): """Schedule a single task to targets. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means clients in targets and haven't received task are eligible for task. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. Raises: @@ -625,16 +639,21 @@ def send_and_wait( ): """Schedule a single task to targets. This is a blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means clients in targets and haven't received task are eligible for task. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this + method to return. Defaults to None. """ self.send( @@ -663,11 +682,13 @@ def cancel_task( note:: - We only mark the task as completed and leave it to the task monitor to clean up. This is to avoid potential deadlock of task_lock. 
+ We only mark the task as completed and leave it to the task monitor to clean up. + This is to avoid potential deadlock of task_lock. Args: task (Task): the task to be cancelled - completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. + completion_status (str, optional): the completion status for this cancellation. + Defaults to TaskCompletionStatus.CANCELLED. fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. """ task.completion_status = completion_status @@ -676,7 +697,8 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ """Cancel all standing tasks in this controller. Args: - completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. + completion_status (str, optional): the completion status for this cancellation. + Defaults to TaskCompletionStatus.CANCELLED. fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. """ with self._task_lock: @@ -695,11 +717,6 @@ def finalize_run(self, fl_ctx: FLContext): """ self.cancel_all_tasks() # unconditionally cancel all tasks self._all_done = True - try: - if self._task_monitor.is_alive(): - self._task_monitor.join() - except RuntimeError: - self.log_debug(fl_ctx, "unable to join monitor thread (not started?)") def relay( self, @@ -713,18 +730,23 @@ def relay( ): """Schedule a single task to targets in one-after-another style. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. - ANY means any clients that are inside the targets and haven't received the task are eligible. Defaults to SendOrder.SEQUENTIAL. + ANY means any clients that are inside the targets and haven't received the task are eligible. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. - dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. + task_result_timeout (int, optional): how long to wait for current working client to reply its result. + Defaults to 0. + dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. + Defaults to True. Raises: ValueError: when task_assignment_timeout is greater than task's timeout @@ -792,18 +814,25 @@ def relay_and_wait( ): """Schedule a single task to targets in one-after-another style. This is a blocking call. - The task is scheduled into a task list. 
Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means clients in targets and haven't received task are eligible for task. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. - dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + task_result_timeout (int, optional): how long to wait for current working client to reply its result. + Defaults to 0. + dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. + Defaults to True. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs + this method to return. Defaults to None. 
""" self.relay( task=task, @@ -816,8 +845,33 @@ def relay_and_wait( ) self.wait_for_task(task, abort_signal) + def _check_dead_clients(self): + if not self._dead_clients: + return + + now = time.time() + with self._dead_clients_lock: + for client_name, status in self._dead_clients.items(): + if status.disconnect_time: + # already disconnected + continue + + if now - status.report_time < self._dead_client_grace: + # this report is still fresh - consider the client to be still alive + continue + + # consider client disconnected + status.disconnect_time = now + self.logger.error(f"Client {client_name} is deemed disconnected!") + with self._engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.DISCONNECTED_CLIENT_NAME, client_name) + self.fire_event(EventType.CLIENT_DISCONNECTED, fl_ctx) + def _monitor_tasks(self): while not self._all_done: + # determine clients are still active or not + self._check_dead_clients() + should_abort_job = self._job_policy_violated() if not should_abort_job: self.check_tasks() @@ -907,29 +961,30 @@ def _get_task_dead_clients(self, task: Task): See whether the task is only waiting for response from a dead client """ now = time.time() - lead_time = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME, default=30.0) + lead_time = ConfigService.get_float_var( + name=ConfigVarName.DEAD_CLIENT_CHECK_LEAD_TIME, conf=SystemConfigs.APPLICATION_CONF, default=30.0 + ) if now - task.schedule_time < lead_time: # due to potential race conditions, we'll wait for at least 1 minute after the task # is started before checking dead clients. return None dead_clients = [] - with self._dead_clients_lock: - for target in task.targets: - ct = _get_client_task(target, task) - if ct is not None and ct.result_received_time: - # response has been received from this client - continue - - # either we have not sent the task to this client or we have not received response - # is the client already dead? - if self._client_still_alive(target): - # this client is still alive - # we let the task continue its course since we still have live clients - return None - else: - # this client is dead - remember it - dead_clients.append(target) + for target in task.targets: + ct = _get_client_task(target, task) + if ct is not None and ct.result_received_time: + # response has been received from this client + continue + + # either we have not sent the task to this client or we have not received response + # is the client already dead? 
+ if self.get_client_disconnect_time(target): + # this client is dead - remember it + dead_clients.append(target) + else: + # this client is still alive + # we let the task continue its course since we still have live clients + return None return dead_clients @@ -964,47 +1019,57 @@ def _job_policy_violated(self): with self._engine.new_context() as fl_ctx: clients = self._engine.get_clients() - with self._dead_clients_lock: - alive_clients = [] - dead_clients = [] - - for client in clients: - if self._client_still_alive(client.name): - alive_clients.append(client.name) - else: - dead_clients.append(client.name) - - if not dead_clients: - return False - - if not alive_clients: - self.log_error(fl_ctx, f"All clients are dead: {dead_clients}") - return True + alive_clients = [] + dead_clients = [] - job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) - job = job_from_meta(job_meta) - if len(alive_clients) < job.min_sites: - self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}") + for client in clients: + if self.get_client_disconnect_time(client.name): + dead_clients.append(client.name) + else: + alive_clients.append(client.name) + + if not dead_clients: + return False + + if not alive_clients: + self.log_error(fl_ctx, f"All clients are dead: {dead_clients}") + return True + + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job = job_from_meta(job_meta) + if len(alive_clients) < job.min_sites: + self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}") + return True + + # check required clients: + if dead_clients and job.required_sites: + dead_required_clients = [c for c in dead_clients if c in job.required_sites] + if dead_required_clients: + self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}") return True - - # check required clients: - if dead_clients and job.required_sites: - dead_required_clients = [c for c in dead_clients if c in job.required_sites] - if dead_required_clients: - self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}") - return True return False - def _client_still_alive(self, client_name): - now = time.time() - report_time = self._dead_client_reports.get(client_name, None) - grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=30.0) + def client_is_active(self, client_name: str, reason: str, fl_ctx: FLContext): + with self._dead_clients_lock: + self.log_debug(fl_ctx, f"client {client_name} is active: {reason}") + if client_name in self._dead_clients: + self.log_info(fl_ctx, f"Client {client_name} is removed from watch list: {reason}") + status = self._dead_clients.pop(client_name) + if status.disconnect_time: + self.log_info(fl_ctx, f"Client {client_name} is reconnected") + fl_ctx.set_prop(FLContextKey.RECONNECTED_CLIENT_NAME, client_name) + self.fire_event(EventType.CLIENT_RECONNECTED, fl_ctx) + + def get_client_disconnect_time(self, client_name: str): + """Get the time that the client was deemed disconnected - if not report_time: - # this client is still alive - return True - elif now - report_time < grace_period: - # this report is still fresh - consider the client to be still alive - return True + Args: + client_name: name of the client - return False + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected + + """ + status = self._dead_clients.get(client_name) + if status: + return status.disconnect_time + return None diff --git a/nvflare/apis/wf_comm_spec.py 
b/nvflare/apis/wf_comm_spec.py index a32c34948a..0ac504f9a1 100644 --- a/nvflare/apis/wf_comm_spec.py +++ b/nvflare/apis/wf_comm_spec.py @@ -284,6 +284,37 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul """ raise NotImplementedError + def get_client_disconnect_time(self, client_name): + """Get the time that the client is deemed disconnected. + + Args: + client_name: the name of the client + + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected. + + """ + raise NotImplementedError + + def process_dead_client_report(self, client_name: str, fl_ctx: FLContext): + """Called by the Engine to process dead client report. + + Args: + client_name: name of the client that dead report is received + fl_ctx: the FLContext + + """ + raise NotImplementedError + + def client_is_active(self, client_name: str, reason: str, fl_ctx: FLContext): + """Called by the Engine to notify us that the client is active . + + Args: + client_name: name of the client that is active + reason: why client is considered active + fl_ctx: the FLContext + """ + raise NotImplementedError + def process_task_check(self, task_id: str, fl_ctx: FLContext): """Called by the Engine to check whether a specified task still exists. Args: diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index 442aaa89be..034a1a7d11 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -18,8 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, Task -from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReturnCode +from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable @@ -141,20 +140,24 @@ def start_controller(self, fl_ctx: FLContext): self._last_client = None def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]: - if len(self._participating_clients) <= 1: - self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx) + active_clients_map = {} + for t in self._participating_clients: + if not self.get_client_disconnect_time(t.name): + active_clients_map[t.name] = t + + if len(active_clients_map) <= 1: + self.system_panic(f"Not enough active client sites ({len(active_clients_map)}).", fl_ctx) return None if isinstance(self._order, list): targets = [] - active_clients_map = {t.name: t for t in self._participating_clients} for c_name in self._order: if c_name not in active_clients_map: self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx) return None targets.append(active_clients_map[c_name]) else: - targets = list(self._participating_clients) + targets = list(active_clients_map.values()) if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW: random.shuffle(targets) if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]: @@ -310,12 +313,3 @@ def restore(self, state_data: dict, fl_ctx: FLContext): self._start_round = self._current_round finally: pass - - def handle_event(self, event_type, fl_ctx): - if event_type == EventType.JOB_DEAD: - client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) - new_client_list = [] - for client in self._participating_clients: - if client_name != 
client.name: - new_client_list.append(client) - self._participating_clients = new_client_list diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index b57f2f68fd..365aefef3f 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -145,7 +145,7 @@ def __init__( self.task_check_timeout = self.get_positive_float_var(ConfigVarName.TASK_CHECK_TIMEOUT, 5.0) self.task_check_interval = self.get_positive_float_var(ConfigVarName.TASK_CHECK_INTERVAL, 5.0) - self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 30.0) + self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 10.0) self.get_task_timeout = self.get_positive_float_var(ConfigVarName.GET_TASK_TIMEOUT, None) self.submit_task_result_timeout = self.get_positive_float_var(ConfigVarName.SUBMIT_TASK_RESULT_TIMEOUT, None) self._register_aux_message_handlers(engine) diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index ae50e4c3a6..2f6334b470 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -36,7 +36,6 @@ ) from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import NotAuthenticated -from nvflare.apis.shareable import Shareable from nvflare.apis.workspace import Workspace from nvflare.fuel.common.exit_codes import ProcessExitCode from nvflare.fuel.f3.cellnet.cell import Cell @@ -140,13 +139,8 @@ def close(self): self.lock.release() except RuntimeError: self.logger.info("canceling sync locks") - try: - # if self.cell: - # self.cell.stop() - pass - finally: - self.logger.info("server off") - return 0 + self.logger.info("server off") + return 0 def deploy(self, args, grpc_args=None, secure_train=False): """Start a grpc server and listening the designated port.""" @@ -624,26 +618,17 @@ def _sync_client_jobs(self, request, client_token): # this is a dict: token => nvflare.apis.client.Client client = participating_clients.get(client_token, None) if client: - self._notify_dead_job(client, job_id) + self._notify_dead_job(client, job_id, "missing job on client") return jobs_need_abort - def _notify_dead_job(self, client, job_id: str): + def _notify_dead_job(self, client, job_id: str, reason: str): try: - with self.engine.lock: - shareable = Shareable() - shareable.set_header(ServerCommandKey.FL_CLIENT, client.name) - fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id]) - request = new_cell_message({}, shareable) - self.cell.fire_and_forget( - targets=fqcn, - channel=CellChannel.SERVER_COMMAND, - topic=ServerCommandNames.HANDLE_DEAD_JOB, - message=request, - optional=True, - ) - except Exception: - self.logger.info("Could not connect to server runner process") + self.engine.notify_dead_job(job_id, client.name, reason) + except Exception as ex: + self.logger.info( + f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}" + ) def notify_dead_client(self, client): """Called to do further processing of the dead client @@ -662,7 +647,7 @@ def notify_dead_client(self, client): assert isinstance(process_info, dict) participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None) if participating_clients and client.token in participating_clients: - self._notify_dead_job(client, job_id) + self._notify_dead_job(client, job_id, "client dead") def start_run(self, job_id, run_root, conf, args, snapshot): # Create the FL 
Engine diff --git a/nvflare/private/fed/server/server_commands.py b/nvflare/private/fed/server/server_commands.py index 37ca856ef8..a4456bc5ce 100644 --- a/nvflare/private/fed/server/server_commands.py +++ b/nvflare/private/fed/server/server_commands.py @@ -263,6 +263,8 @@ def process(self, data: Shareable, fl_ctx: FLContext): """ client_name = data.get_header(ServerCommandKey.FL_CLIENT) + reason = data.get_header(ServerCommandKey.REASON) + self.logger.warning(f"received dead job notification: {reason=}") server_runner = fl_ctx.get_prop(FLContextKey.RUNNER) if server_runner: server_runner.handle_dead_job(client_name, fl_ctx) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index cc0c3a0484..60f28bb758 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -577,13 +577,26 @@ def update_job_run_status(self): data = {"execution_error": execution_error} job_id = fl_ctx.get_job_id() request = new_cell_message({CellMessageHeaderKeys.JOB_ID: job_id}, data) - return_data = self.server.cell.fire_and_forget( + self.server.cell.fire_and_forget( targets=FQCN.ROOT_SERVER, channel=CellChannel.SERVER_PARENT_LISTENER, topic=ServerCommandNames.UPDATE_RUN_STATUS, message=request, ) + def notify_dead_job(self, job_id: str, client_name: str, reason: str): + shareable = Shareable() + shareable.set_header(ServerCommandKey.FL_CLIENT, client_name) + shareable.set_header(ServerCommandKey.REASON, reason) + self.send_command_to_child_runner_process( + job_id=job_id, + command_name=ServerCommandNames.HANDLE_DEAD_JOB, + command_data=shareable, + timeout=0.0, + optional=True, + ) + self.logger.warning(f"notified SJ of dead-job: {job_id=}; {client_name=}; {reason=}") + def send_command_to_child_runner_process( self, job_id: str, command_name: str, command_data, timeout=5.0, optional=False ): @@ -595,7 +608,7 @@ def send_command_to_child_runner_process( targets=fqcn, channel=CellChannel.SERVER_COMMAND, topic=command_name, - request=request, + message=request, optional=optional, ) return None diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 9fbceb8880..4831cacfe8 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -116,6 +116,7 @@ def _register_aux_message_handler(self, engine): def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: # simply ack + self._report_client_active("syncRunner", fl_ctx) return make_reply(ReturnCode.OK) def _execute_run(self): @@ -278,6 +279,8 @@ def process_task_request(self, client: Client, fl_ctx: FLContext) -> (str, str, self.log_debug(fl_ctx, "invalid task request: no peer context - asked client to try again later") return self._task_try_again() + self._report_client_active("getTask", fl_ctx) + peer_job_id = peer_ctx.get_job_id() if not peer_job_id or peer_job_id != self.job_id: # the client is in a different RUN @@ -383,9 +386,8 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if self.current_wf is None: return - fl_ctx.set_prop(FLContextKey.DEAD_JOB_CLIENT_NAME, client_name) - self.log_debug(fl_ctx, "firing event EventType.JOB_DEAD") - self.fire_event(EventType.JOB_DEAD, fl_ctx) + if self.current_wf.controller: + self.current_wf.controller.communicator.process_dead_client_report(client_name, fl_ctx) except Exception as e: self.log_exception( @@ -408,6 +410,7 @@ def process_submission(self, client: Client, 
task_name: str, task_id: str, resul fl_ctx: FLContext """ self.log_info(fl_ctx, f"got result from client {client.name} for task: name={task_name}, id={task_id}") + self._report_client_active("submitTaskResult", fl_ctx) if not isinstance(result, Shareable): self.log_error(fl_ctx, "invalid result submission: must be Shareable but got {}".format(type(result))) @@ -503,11 +506,21 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul "Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)), ) + def _report_client_active(self, reason: str, fl_ctx: FLContext): + with self.wf_lock: + if self.current_wf and self.current_wf.controller: + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + client_name = peer_ctx.get_identity_name() + self.current_wf.controller.communicator.client_is_active(client_name, reason, fl_ctx) + def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: self.log_debug(fl_ctx, "received client job_heartbeat") + self._report_client_active("jobHeartbeat", fl_ctx) return make_reply(ReturnCode.OK) def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + self._report_client_active("taskCheck", fl_ctx) task_id = request.get_header(ReservedHeaderKey.TASK_ID) if not task_id: self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request") diff --git a/nvflare/private/fed/server/sys_cmd.py b/nvflare/private/fed/server/sys_cmd.py index c684e08073..6fb89266ae 100644 --- a/nvflare/private/fed/server/sys_cmd.py +++ b/nvflare/private/fed/server/sys_cmd.py @@ -76,6 +76,14 @@ def get_spec(self): authz_func=self.authorize_client_operation, visible=True, ), + CommandSpec( + name="dead", + description="send dead client msg to SJ", + usage="dead ", + handler_func=self.dead_client, + authz_func=self.must_be_project_admin, + visible=False, + ), ], ) @@ -175,3 +183,13 @@ def report_env(self, conn: Connection, args: List[str]): table = conn.append_table(["Sites", "Env"], name=MetaKey.CLIENTS) for k, v in site_resources.items(): table.add_row([str(k), str(v)], meta=v) + + def dead_client(self, conn: Connection, args: List[str]): + if len(args) != 3: + conn.append_error(f"Usage: {args[0]} client_name job_id") + return + client_name = args[1] + job_id = args[2] + engine = conn.app_ctx + engine.notify_dead_job(job_id, client_name, f"AdminCommand: {args[0]}") + conn.append_string(f"called notify_dead_job for client {client_name=} {job_id=}") From d29f5deec6576a9bf36fd2997637175a29c115cb Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 18 Apr 2024 14:16:12 -0700 Subject: [PATCH 05/14] Enhance WFController (#2505) * set flmodel variables in basefedavg * make round info optional, fix inproc api bug --- .../in_process_client_api_executor.py | 18 +++++++++++++++--- .../app_common/executors/launcher_executor.py | 7 +++---- nvflare/app_common/workflows/base_fedavg.py | 1 + nvflare/app_common/workflows/fedavg.py | 2 +- .../app_common/workflows/model_controller.py | 11 ++++++----- nvflare/app_common/workflows/wf_controller.py | 6 +++--- nvflare/client/in_process/api.py | 15 ++++----------- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index 894cfdbf2c..0d70a61a1f 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ 
b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -23,6 +23,7 @@ from nvflare.apis.signal import Signal from nvflare.apis.utils.analytix_utils import create_analytic_dxo from nvflare.app_common.abstract.params_converter import ParamsConverter +from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.exec_task_fn_wrapper import ExecTaskFuncWrapper from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE from nvflare.app_common.widgets.streaming import send_analytic_dxo @@ -107,6 +108,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run) self._task_fn_thread.start() + meta = self._prepare_task_meta(fl_ctx, None) + self.client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=0.5) + self.client_api.init() + self._data_bus.put_data(CLIENT_API_KEY, self.client_api) + elif event_type == EventType.END_RUN: self._event_manager.fire_event(TOPIC_STOP, "END_RUN received") if self._task_fn_thread: @@ -118,9 +124,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort fl_ctx.set_prop("abort_signal", abort_signal) meta = self._prepare_task_meta(fl_ctx, task_name) - client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=0.5) - client_api.init() - self._data_bus.put_data(CLIENT_API_KEY, client_api) + self.client_api.set_meta(meta) shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id()) shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name()) @@ -142,6 +146,14 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort if self.local_result: result = self.local_result self.local_result = None + + if not isinstance(result, Shareable): + self.log_error(fl_ctx, f"bad task result from peer: expect Shareable but got {type(result)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + current_round = shareable.get_header(AppConstants.CURRENT_ROUND) + if current_round is not None: + result.set_header(AppConstants.CURRENT_ROUND, current_round) if self._to_nvflare_converter is not None: result = self._to_nvflare_converter.process(task_name, result, fl_ctx) return result diff --git a/nvflare/app_common/executors/launcher_executor.py b/nvflare/app_common/executors/launcher_executor.py index d62f0671fe..d4dfe964df 100644 --- a/nvflare/app_common/executors/launcher_executor.py +++ b/nvflare/app_common/executors/launcher_executor.py @@ -181,12 +181,11 @@ def check_input_shareable(self, task_name: str, shareable: Shareable, fl_ctx: FL total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) if task_name == self._train_task_name: if current_round is None: - self.log_error(fl_ctx, "missing current round") - return False + self.log_warning(fl_ctx, f"no current round for task {task_name}") if total_rounds is None: - self.log_error(fl_ctx, "missing total number of rounds") - return False + self.log_warning(fl_ctx, f"no total number of rounds for task {task_name}") + return True def check_output_shareable(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> bool: diff --git a/nvflare/app_common/workflows/base_fedavg.py b/nvflare/app_common/workflows/base_fedavg.py index f43eee24d3..7587fa40d5 100644 --- a/nvflare/app_common/workflows/base_fedavg.py +++ b/nvflare/app_common/workflows/base_fedavg.py @@ -71,6 +71,7 @@ def __init__( self.num_rounds = num_rounds self.start_round = start_round self.persist_every_n_rounds = persist_every_n_rounds + 
self.current_round = None def sample_clients(self, num_clients): diff --git a/nvflare/app_common/workflows/fedavg.py b/nvflare/app_common/workflows/fedavg.py index 61edca34fc..d03ae8999b 100644 --- a/nvflare/app_common/workflows/fedavg.py +++ b/nvflare/app_common/workflows/fedavg.py @@ -50,10 +50,10 @@ def run(self) -> None: for self.current_round in range(self.start_round, self.start_round + self.num_rounds): self.info(f"Round {self.current_round} started.") + model.current_round = self.current_round clients = self.sample_clients(self.min_clients) - model.current_round = self.current_round results = self.send_model_and_wait(targets=clients, data=model) aggregate_results = self.aggregate( diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index f891ef8ee0..d5db8215e3 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -37,7 +37,7 @@ class ModelController(Controller, FLComponentWrapper, ABC): def __init__( self, - persistor_id="", + persistor_id="persistor", ignore_result_error: bool = False, allow_empty_global_weights: bool = False, task_check_period: float = 0.5, @@ -79,11 +79,11 @@ def start_controller(self, fl_ctx: FLContext) -> None: if self._persistor_id: self._persistor = self._engine.get_component(self._persistor_id) if not isinstance(self._persistor, LearnablePersistor): - self.panic( + self.warning( f"Model Persistor {self._persistor_id} must be a LearnablePersistor type object, " f"but got {type(self._persistor)}" ) - return + self._persistor = None self.engine = self.fl_ctx.get_engine() FLComponentWrapper.initialize(self) @@ -96,7 +96,7 @@ def _build_shareable(self, data: FLModel = None) -> Shareable: return data_shareable - def send_model( + def broadcast_model( self, task_name: str = AppConstants.TASK_TRAIN, data: FLModel = None, @@ -216,7 +216,8 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: result_model = FLModelUtils.from_shareable(result) result_model.meta["client_name"] = client_name - self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True) + if result_model.current_round is not None: + self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True) self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) diff --git a/nvflare/app_common/workflows/wf_controller.py b/nvflare/app_common/workflows/wf_controller.py index 668bd6e348..021de51c10 100644 --- a/nvflare/app_common/workflows/wf_controller.py +++ b/nvflare/app_common/workflows/wf_controller.py @@ -23,7 +23,7 @@ class WFController(ModelController, ABC): def __init__( self, *args, - persistor_id: str = "", + persistor_id: str = "persistor", **kwargs, ): """Workflow Controller API for FLModel-based ModelController. 
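
For context, a workflow built on this controller API might look roughly like the sketch below. It is illustrative only, not part of this patch, and assumes the `WFController` surface shown in these hunks: a `run()` entry point, `send_model_and_wait(...)` returning a list of `FLModel`, and the new default `persistor_id="persistor"`; the task name and client names are placeholders.

```
# Illustrative sketch only -- not part of this patch.
# Assumes WFController.run() drives the workflow and send_model_and_wait()
# returns List[FLModel], as in the wf_controller.py changes in this patch.
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows.wf_controller import WFController


class NaiveAvg(WFController):
    def __init__(self, num_rounds: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_rounds = num_rounds

    def run(self) -> None:
        model = FLModel(params={})  # initial weights would normally come from the persistor
        for current_round in range(self.num_rounds):
            # set the round on the model before sending, mirroring the FedAvg change above
            model.current_round = current_round
            results = self.send_model_and_wait(
                task_name="train",              # placeholder task name
                data=model,
                targets=["site-1", "site-2"],   # placeholder client names
            )
            # naive unweighted average of the returned parameters (illustration only)
            if results:
                keys = results[0].params.keys()
                model.params = {k: sum(r.params[k] for r in results) / len(results) for k in keys}
```
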
@@ -58,7 +58,7 @@ def send_model_and_wait( Returns: List[FLModel] """ - return super().send_model( + return super().broadcast_model( task_name=task_name, data=data, targets=targets, @@ -88,7 +88,7 @@ def send_model( Returns: None """ - super().send_model( + super().broadcast_model( task_name=task_name, data=data, targets=targets, diff --git a/nvflare/client/in_process/api.py b/nvflare/client/in_process/api.py index 9ee7084b15..2bb0029bdc 100644 --- a/nvflare/client/in_process/api.py +++ b/nvflare/client/in_process/api.py @@ -50,12 +50,9 @@ def __init__(self, task_metadata: dict, result_check_interval: float = 2.0): self.meta = task_metadata self.result_check_interval = result_check_interval - self.start_round = None self.fl_model = None self.sys_info = {} self.client_config: Optional[ClientConfig] = None - self.current_round = None - self.total_rounds = None self.logger = logging.getLogger(self.__class__.__name__) self.event_manager = EventManager(self.data_bus) self.abort_reason = "" @@ -96,6 +93,9 @@ def prepare_client_config(self, config): client_config.config = self.meta self.client_config = client_config + def set_meta(self, meta: dict): + self.meta = meta + def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]: if self.fl_model: return self.fl_model @@ -149,14 +149,7 @@ def is_running(self) -> bool: else: self.receive() - if self.fl_model: - self.current_round = self.fl_model.current_round - self.total_rounds = self.fl_model.total_rounds - self.start_round = self.fl_model.meta.get(FLMetaKey.START_ROUND, 0) - else: - return False - - return self.current_round < self.start_round + self.total_rounds + return self.fl_model is not None def is_train(self) -> bool: if self.rank != "0": From c33cc4c0726e492d902817aac7db53b5f49448f2 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 18 Apr 2024 16:47:04 -0700 Subject: [PATCH 06/14] temporarily disable preflight tests (#2521) --- tests/integration_test/run_integration_tests.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration_test/run_integration_tests.sh b/tests/integration_test/run_integration_tests.sh index 6f8e68a981..49021f6104 100755 --- a/tests/integration_test/run_integration_tests.sh +++ b/tests/integration_test/run_integration_tests.sh @@ -79,8 +79,8 @@ if [[ $m == "tensorflow" ]]; then run_tensorflow elif [[ $m == "overseer" ]]; then run_overseer_test -elif [[ $m == "preflight" ]]; then - run_preflight_check_test +# elif [[ $m == "preflight" ]]; then +# run_preflight_check_test else run_system_test fi From ac9ff92abffdc4c8d7bf6ef6e70e06e4fb4cec30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Thu, 18 Apr 2024 16:49:39 -0700 Subject: [PATCH 07/14] Upgrade dependencies (#2516) --- setup.cfg | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7b1b6e834f..c07c44e483 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,15 +17,15 @@ zip_safe = True python_requires = >= 3.8 install_requires = cryptography>=36.0.0 - Flask==2.2.5 - Werkzeug==2.2.2 - Flask-JWT-Extended==4.4.3 - Flask-SQLAlchemy==2.5.1 - SQLAlchemy==1.4.31 - grpcio==1.51.1 + Flask==3.0.2 + Werkzeug==3.0.1 + Flask-JWT-Extended==4.6.0 + Flask-SQLAlchemy==3.1.1 + SQLAlchemy==2.0.16 + grpcio==1.62.1 gunicorn>=20.1.0 numpy - protobuf==3.20.3 + protobuf==4.24.4 psutil>=5.9.1 PyYAML>=6.0 requests>=2.28.0 @@ -45,6 +45,7 @@ PT = torchvision SKLEARN = scikit-learn + pandas>=1.5.1 TRACKING = mlflow wandb From 
3db9365ebc5afe1c292af4be6cbd18734bc28d1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Thu, 18 Apr 2024 16:49:54 -0700 Subject: [PATCH 08/14] Use full path for PSI components (#2437) (#2517) --- .../advanced/psi/user_email_match/README.md | 59 ++++++++----------- .../app/config/config_fed_client.conf | 12 ++-- .../app/config/config_fed_server.conf | 4 +- job_templates/psi_csv/config_fed_client.conf | 8 +-- job_templates/psi_csv/config_fed_server.conf | 2 +- 5 files changed, 39 insertions(+), 46 deletions(-) diff --git a/examples/advanced/psi/user_email_match/README.md b/examples/advanced/psi/user_email_match/README.md index 6952fc25f2..2b01108e8e 100644 --- a/examples/advanced/psi/user_email_match/README.md +++ b/examples/advanced/psi/user_email_match/README.md @@ -13,42 +13,36 @@ These items could be user_ids or feature names depending on your use case. ``` { - "format_version": 2, - "executors": [ + format_version = 2 + executors = [ { - "tasks": [ - "PSI" - ], - "executor": { - "id": "Executor", - "name": "PSIExecutor", - "args": { - "psi_algo_id": "dh_psi" - } + tasks = ["PSI"] + executor { + id = "Executor" + path = "nvflare.app_common.psi.psi_executor.PSIExecutor" + args.psi_algo_id = "dh_psi" } } - ], - "components": [ + ] + + components = [ { - "id": "dh_psi", - "name": "DhPSITaskHandler", - "args": { - "local_psi_id": "local_psi" - } + id = "dh_psi" + path = "nvflare.app_opt.psi.dh_psi.dh_psi_task_handler.DhPSITaskHandler" + args.local_psi_id = "local_psi" }, { - "id": "local_psi", - "path": "local_psi.LocalPSI", - "args": { - "psi_writer_id": "psi_writer" + id = "local_psi" + path = "local_psi.LocalPSI" + args { + psi_writer_id = "psi_writer" + data_root_dir = "/tmp/nvflare/psi/data" } }, { - "id": "psi_writer", - "name": "FilePSIWriter", - "args": { - "output_path": "psi/intersection.txt" - } + id = "psi_writer", + path = "nvflare.app_common.psi.file_psi_writer.FilePSIWriter" + args.output_path = "psi/intersection.txt" } ] } @@ -67,17 +61,16 @@ a file writer Just specify the built-in PSI controller. 
``` { - "format_version": 2, - "workflows": [ + format_version = 2, + workflows = [ { - "id": "controller", - "name": "DhPSIController", - "args": { + id = "DhPSIController" + path = "nvflare.app_common.psi.dh_psi.dh_psi_controller.DhPSIController" + args{ } } ] } - ``` **Code** the code is really trivial just needs to implement one method in PSI interface diff --git a/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_client.conf b/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_client.conf index 32b4f9c50b..d233433576 100644 --- a/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_client.conf +++ b/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_client.conf @@ -5,7 +5,7 @@ tasks = ["PSI"] executor { id = "Executor" - name = "PSIExecutor" + path = "nvflare.app_common.psi.psi_executor.PSIExecutor" args.psi_algo_id = "dh_psi" } } @@ -14,20 +14,20 @@ components = [ { id = "dh_psi" - name = "DhPSITaskHandler" + path = "nvflare.app_opt.psi.dh_psi.dh_psi_task_handler.DhPSITaskHandler" args.local_psi_id = "local_psi" }, { id = "local_psi" - path="local_psi.LocalPSI" + path = "local_psi.LocalPSI" args { - psi_writer_id="psi_writer", - data_root_dir="/tmp/nvflare/psi/data" + psi_writer_id = "psi_writer" + data_root_dir = "/tmp/nvflare/psi/data" } }, { id = "psi_writer", - name = "FilePSIWriter", + path = "nvflare.app_common.psi.file_psi_writer.FilePSIWriter" args.output_path = "psi/intersection.txt" } ] diff --git a/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_server.conf b/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_server.conf index 23129f8817..c17696aa53 100644 --- a/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_server.conf +++ b/examples/advanced/psi/user_email_match/jobs/user_email_match/app/config/config_fed_server.conf @@ -2,8 +2,8 @@ format_version = 2, workflows = [ { - id="DhPSIController" - name="DhPSIController" + id = "DhPSIController" + path = "nvflare.app_common.psi.dh_psi.dh_psi_controller.DhPSIController" args{ } } diff --git a/job_templates/psi_csv/config_fed_client.conf b/job_templates/psi_csv/config_fed_client.conf index aec8b2b7db..14d8c91edb 100644 --- a/job_templates/psi_csv/config_fed_client.conf +++ b/job_templates/psi_csv/config_fed_client.conf @@ -7,7 +7,7 @@ executors = [ executor { # built in PSIExecutor id = "psi_executor" - name = "PSIExecutor" + path = "nvflare.app_common.psi.psi_executor.PSIExecutor" args { psi_algo_id = "dh_psi" } @@ -17,13 +17,13 @@ executors = [ components = [ { id = "dh_psi" - name = "DhPSITaskHandler" + path = "nvflare.app_opt.psi.dh_psi.dh_psi_task_handler.DhPSITaskHandler" args { local_psi_id = "local_psi" } } { - # custome component to load the items for the PSI algorithm + # custom component to load the items for the PSI algorithm id = "local_psi" path = "local_psi.LocalPSI" args { @@ -37,7 +37,7 @@ components = [ { # saves the calculated intersection to a file in the workspace id = "psi_writer" - name = "FilePSIWriter" + path = "nvflare.app_common.psi.file_psi_writer.FilePSIWriter" args { output_path = "psi/intersection.txt" } diff --git a/job_templates/psi_csv/config_fed_server.conf b/job_templates/psi_csv/config_fed_server.conf index 6c4d91d431..fd54c8c98a 100644 --- a/job_templates/psi_csv/config_fed_server.conf +++ b/job_templates/psi_csv/config_fed_server.conf @@ -2,7 +2,7 @@ format_version 
= 2 workflows = [ { id = "controller" - name = "DhPSIController" + path = "nvflare.app_common.psi.dh_psi.dh_psi_controller.DhPSIController" args { } } From 8c91a284d12dbe5880ef4ba0dd12cf4d154bf9cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Thu, 18 Apr 2024 16:50:09 -0700 Subject: [PATCH 09/14] Multiple bug fixes from 2.4 (#2518) * [2.4] Support client custom code in simulator (#2447) * Support client custom code in simulator * Fix client custom code * Remove cancel_futures args (#2457) * Fix sub_worker_process shutdown (#2458) * Set GRPC_ENABLE_FORK_SUPPORT to False (#2474) --- nvflare/apis/utils/reliable_message.py | 2 +- nvflare/fuel/f3/drivers/aio_grpc_driver.py | 5 +++++ nvflare/fuel/f3/drivers/grpc_driver.py | 3 +++ nvflare/private/fed/app/client/sub_worker_process.py | 2 -- nvflare/private/fed/app/simulator/simulator_worker.py | 2 ++ 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index b222336d53..802d2aff2e 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -357,7 +357,7 @@ def shutdown(cls): """ if not cls._shutdown_asked: cls._shutdown_asked = True - cls._executor.shutdown(cancel_futures=True, wait=False) + cls._executor.shutdown(wait=False) cls._logger.info("ReliableMessage is shutdown") @classmethod diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 2837b2e0b2..9c91f3271f 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import asyncio +import os import random import threading import time @@ -278,6 +280,9 @@ class AioGrpcDriver(BaseDriver): def __init__(self): super().__init__() + # GRPC with fork issue: https://github.com/grpc/grpc/issues/28557 + os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "False" + self.server = None self.options = GRPC_DEFAULT_OPTIONS self.logger = get_logger(self) diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 2e677584c7..2ec4009fbf 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import threading from concurrent import futures from typing import Any, Dict, List, Union @@ -202,6 +203,8 @@ def shutdown(self): class GrpcDriver(BaseDriver): def __init__(self): BaseDriver.__init__(self) + # GRPC with fork issue: https://github.com/grpc/grpc/issues/28557 + os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "False" self.server = None self.closing = False self.max_workers = 100 diff --git a/nvflare/private/fed/app/client/sub_worker_process.py b/nvflare/private/fed/app/client/sub_worker_process.py index 9ad491f23b..ec911961bf 100644 --- a/nvflare/private/fed/app/client/sub_worker_process.py +++ b/nvflare/private/fed/app/client/sub_worker_process.py @@ -294,8 +294,6 @@ def _handle_event(self, data): def _close(self, data): self.done = True - self.cell.stop() - # mpm.stop() def run(self): self.logger.info("SubWorkerExecutor process started.") diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index 1fd9ba9614..85cde612c7 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -235,6 +235,8 @@ def main(args): log_file = os.path.join(args.workspace, WorkspaceConstants.LOG_FILE_NAME) add_logfile_handler(log_file) + app_custom_folder = os.path.join(args.workspace, "custom") + sys.path.append(app_custom_folder) os.chdir(args.workspace) fobs_initialize() AuthorizationService.initialize(EmptyAuthorizer()) From bc7d96d3715886c0d60af6bfc687d9c713ee048c Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 18 Apr 2024 20:21:46 -0400 Subject: [PATCH 10/14] Pythonic job creation (#2483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WIP: constructed the FedJob. * WIP: server_app josn export. * generate the job app config. * fully functional pythonic job creation. * Added simulator_run for pythonic API. * reformat. * Added filters support for pythonic job creation. * handled the direct import case in fed_job. * refactor. * Added the resource_spec set function for FedJob. * refactored. * Moved the ClientApp and ServerApp into fed_app.py. * Refactored: removed the _FilterDef class. * refactored. * Rename job config classes (#3) * rename config related classes * add client api example * fix metric streaming * add to() routine * Enable obj in the constructor as paramenter. * Added support for the launcher script. * refactored. * reformat. * Update the comment. * re-arrange the package location. * Added add_ext_script() for BaseAppConfig. * codestyle fix. * Removed the client-api-pt example. * removed no used import. * fixed the in_time_accumulate_weighted_aggregator_test.py * Added Enum parameter support. * Added docstring. * Added ability to handle parameters from base class. * Move the parameter data format conversion to the START_RUN event for InProcessClientAPIExecutor. * Added params_exchange_format for PTInProcessClientAPIExecutor. * codestyle fix. * Fixed a custom code folder structure issue. * work for sub-folder custom files. * backed to handle parameters from base classes. * Support folder structure job config. * Added support for flat folder from '.XXX' import. * codestyle fix. * refactored and add docstring. * Address some of the PR reviews. 
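
In short, the classes added below let a job be defined and run directly in Python instead of hand-written JSON configs. A minimal sketch of that API, using only the classes and methods introduced in this patch (`MyController` and `MyExecutor` are placeholders for any real `Controller`/`Executor`), looks like:

```
# Minimal sketch of the Pythonic job creation API added in this patch.
# MyController / MyExecutor are placeholders for real Controller / Executor classes.
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig

server_app = ServerAppConfig()
server_app.add_workflow("controller", MyController())

client_app = ClientAppConfig()
client_app.add_executor(["train"], MyExecutor())

job = FedJobConfig(job_name="my-job", min_clients=2)
job.add_fed_app("app", FedAppConfig(server_app=server_app, client_app=client_app))
job.set_site_app("@ALL", "app")

# export a conventional job config folder, or run the job directly in the simulator
job.generate_job_config("/tmp/nvflare/jobs")
# job.simulator_run("/tmp/nvflare/jobs", "/tmp/nvflare/simulator_workspace", threads=2)
```

The full hello-pt example that follows in this patch (`hello_pt_job.py`) shows the same flow with real executors, filters, and workflows.
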
--------- Co-authored-by: Holger Roth <6304754+holgerroth@users.noreply.github.com> Co-authored-by: Yuan-Ting Hsieh (謝沅廷) Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> --- .../hello-pt/add_shareable_parameter.py | 24 ++ .../job_config/hello-pt/cifar10trainer.py | 200 +++++++++++ .../job_config/hello-pt/cifar10validator.py | 112 ++++++ .../job_config/hello-pt/hello_pt_job.py | 120 +++++++ .../hello-pt/print_shareable_parameter.py | 33 ++ .../job_config/hello-pt/pt_constants.py | 21 ++ .../job_config/hello-pt/pt_model_locator.py | 66 ++++ .../job_config/hello-pt/simple_network.py | 38 +++ .../intime_accumulate_model_aggregator.py | 18 +- .../in_process_client_api_executor.py | 12 +- .../pt/in_process_client_api_executor.py | 3 +- nvflare/job_config/__init__.py | 13 + nvflare/job_config/base_app_config.py | 72 ++++ nvflare/job_config/fed_app_config.py | 88 +++++ nvflare/job_config/fed_job_config.py | 321 ++++++++++++++++++ ...ime_accumulate_weighted_aggregator_test.py | 13 +- 16 files changed, 1143 insertions(+), 11 deletions(-) create mode 100644 examples/advanced/job_config/hello-pt/add_shareable_parameter.py create mode 100644 examples/advanced/job_config/hello-pt/cifar10trainer.py create mode 100644 examples/advanced/job_config/hello-pt/cifar10validator.py create mode 100644 examples/advanced/job_config/hello-pt/hello_pt_job.py create mode 100644 examples/advanced/job_config/hello-pt/print_shareable_parameter.py create mode 100644 examples/advanced/job_config/hello-pt/pt_constants.py create mode 100644 examples/advanced/job_config/hello-pt/pt_model_locator.py create mode 100644 examples/advanced/job_config/hello-pt/simple_network.py create mode 100644 nvflare/job_config/__init__.py create mode 100644 nvflare/job_config/base_app_config.py create mode 100644 nvflare/job_config/fed_app_config.py create mode 100644 nvflare/job_config/fed_job_config.py diff --git a/examples/advanced/job_config/hello-pt/add_shareable_parameter.py b/examples/advanced/job_config/hello-pt/add_shareable_parameter.py new file mode 100644 index 0000000000..f7c819c1c7 --- /dev/null +++ b/examples/advanced/job_config/hello-pt/add_shareable_parameter.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.filter import Filter +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable + + +class AddShareable(Filter): + def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable: + print(f"{fl_ctx.get_identity_name()} ---- AddShareable Filter ----") + + return shareable diff --git a/examples/advanced/job_config/hello-pt/cifar10trainer.py b/examples/advanced/job_config/hello-pt/cifar10trainer.py new file mode 100644 index 0000000000..181e84ad13 --- /dev/null +++ b/examples/advanced/job_config/hello-pt/cifar10trainer.py @@ -0,0 +1,200 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path + +import torch +from pt_constants import PTConstants +from simple_network import SimpleNetwork +from torch import nn +from torch.optim import SGD +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReservedKey, ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager + + +class Cifar10Trainer(Executor): + def __init__( + self, + data_path="~/data", + lr=0.01, + epochs=5, + train_task_name=AppConstants.TASK_TRAIN, + submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, + exclude_vars=None, + pre_train_task_name=AppConstants.TASK_GET_WEIGHTS, + ): + """Cifar10 Trainer handles train and submit_model tasks. During train_task, it trains a + simple network on CIFAR10 dataset. For submit_model task, it sends the locally trained model + (if present) to the server. + + Args: + lr (float, optional): Learning rate. Defaults to 0.01 + epochs (int, optional): Epochs. Defaults to 5 + train_task_name (str, optional): Task name for train task. Defaults to "train". + submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model". + exclude_vars (list): List of variables to exclude during model loading. + pre_train_task_name: Task name for pre train task, i.e., sending initial model weights. + """ + super().__init__() + + self._lr = lr + self._epochs = epochs + self._train_task_name = train_task_name + self._pre_train_task_name = pre_train_task_name + self._submit_model_task_name = submit_model_task_name + self._exclude_vars = exclude_vars + + # Training setup + self.model = SimpleNetwork() + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + self.loss = nn.CrossEntropyLoss() + self.optimizer = SGD(self.model.parameters(), lr=lr, momentum=0.9) + + # Create Cifar10 dataset for training. + transforms = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + self._train_dataset = CIFAR10(root=data_path, transform=transforms, download=True, train=True) + self._train_loader = DataLoader(self._train_dataset, batch_size=4, shuffle=True) + self._n_iterations = len(self._train_loader) + + # Setup the persistence manager to save PT model. + # The default training configuration is used by persistence manager + # in case no initial model is found. 
+ self._default_train_conf = {"train": {"model": type(self.model).__name__}} + self.persistence_manager = PTModelPersistenceFormatManager( + data=self.model.state_dict(), default_train_conf=self._default_train_conf + ) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + try: + if task_name == self._pre_train_task_name: + # Get the new state dict and send as weights + return self._get_model_weights() + elif task_name == self._train_task_name: + # Get model weights + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, "Unable to extract dxo from shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Ensure data kind is weights. + if not dxo.data_kind == DataKind.WEIGHTS: + self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Convert weights to tensor. Run training + torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()} + self._local_train(fl_ctx, torch_weights, abort_signal) + + # Check the abort_signal after training. + # local_train returns early if abort_signal is triggered. + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # Save the local model after training. + self._save_local_model(fl_ctx) + + # Get the new state dict and send as weights + return self._get_model_weights() + elif task_name == self._submit_model_task_name: + # Load local model + ml = self._load_local_model(fl_ctx) + + # Get the model parameters and create dxo from it + dxo = model_learnable_to_dxo(ml) + return dxo.to_shareable() + else: + return make_reply(ReturnCode.TASK_UNKNOWN) + except Exception as e: + self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + def _get_model_weights(self) -> Shareable: + # Get the new state dict and send as weights + weights = {k: v.cpu().numpy() for k, v in self.model.state_dict().items()} + + outgoing_dxo = DXO( + data_kind=DataKind.WEIGHTS, data=weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations} + ) + return outgoing_dxo.to_shareable() + + def _local_train(self, fl_ctx, weights, abort_signal): + # Set the model weights + self.model.load_state_dict(state_dict=weights) + + # Basic training + self.model.train() + for epoch in range(self._epochs): + running_loss = 0.0 + for i, batch in enumerate(self._train_loader): + if abort_signal.triggered: + # If abort_signal is triggered, we simply return. + # The outside function will check it again and decide steps to take. 
+ return + + images, labels = batch[0].to(self.device), batch[1].to(self.device) + self.optimizer.zero_grad() + + predictions = self.model(images) + cost = self.loss(predictions, labels) + cost.backward() + self.optimizer.step() + + running_loss += cost.cpu().detach().numpy() / images.size()[0] + if i % 3000 == 0: + self.log_info( + fl_ctx, f"Epoch: {epoch}/{self._epochs}, Iteration: {i}, " f"Loss: {running_loss/3000}" + ) + running_loss = 0.0 + + def _save_local_model(self, fl_ctx: FLContext): + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) + models_dir = os.path.join(run_dir, PTConstants.PTModelsDir) + if not os.path.exists(models_dir): + os.makedirs(models_dir) + model_path = os.path.join(models_dir, PTConstants.PTLocalModelName) + + ml = make_model_learnable(self.model.state_dict(), {}) + self.persistence_manager.update(ml) + torch.save(self.persistence_manager.to_persistence_dict(), model_path) + + def _load_local_model(self, fl_ctx: FLContext): + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) + models_dir = os.path.join(run_dir, PTConstants.PTModelsDir) + if not os.path.exists(models_dir): + return None + model_path = os.path.join(models_dir, PTConstants.PTLocalModelName) + + self.persistence_manager = PTModelPersistenceFormatManager( + data=torch.load(model_path), default_train_conf=self._default_train_conf + ) + ml = self.persistence_manager.to_model_learnable(exclude_vars=self._exclude_vars) + return ml diff --git a/examples/advanced/job_config/hello-pt/cifar10validator.py b/examples/advanced/job_config/hello-pt/cifar10validator.py new file mode 100644 index 0000000000..80a0c7b714 --- /dev/null +++ b/examples/advanced/job_config/hello-pt/cifar10validator.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from simple_network import SimpleNetwork +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +from nvflare.apis.dxo import DXO, DataKind, from_shareable +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.app_constant import AppConstants + + +class Cifar10Validator(Executor): + def __init__(self, data_path="~/data", validate_task_name=AppConstants.TASK_VALIDATION): + super().__init__() + + self._validate_task_name = validate_task_name + + # Setup the model + self.model = SimpleNetwork() + self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + self.model.to(self.device) + + # Preparing the dataset for testing. 
+ transforms = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + test_data = CIFAR10(root=data_path, train=False, transform=transforms) + self._test_loader = DataLoader(test_data, batch_size=4, shuffle=False) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + if task_name == self._validate_task_name: + model_owner = "?" + try: + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, "Error in extracting dxo from shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Ensure data_kind is weights. + if not dxo.data_kind == DataKind.WEIGHTS: + self.log_exception(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Extract weights and ensure they are tensor. + model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?") + weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()} + + # Get validation accuracy + val_accuracy = self._validate(weights, abort_signal) + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + self.log_info( + fl_ctx, + f"Accuracy when validating {model_owner}'s model on" + f" {fl_ctx.get_identity_name()}" + f"s data: {val_accuracy}", + ) + + dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy}) + return dxo.to_shareable() + except: + self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + else: + return make_reply(ReturnCode.TASK_UNKNOWN) + + def _validate(self, weights, abort_signal): + self.model.load_state_dict(weights) + + self.model.eval() + + correct = 0 + total = 0 + with torch.no_grad(): + for i, (images, labels) in enumerate(self._test_loader): + if abort_signal.triggered: + return 0 + + images, labels = images.to(self.device), labels.to(self.device) + output = self.model(images) + + _, pred_label = torch.max(output, 1) + + correct += (pred_label == labels).sum().item() + total += images.size()[0] + + metric = correct / float(total) + + return metric diff --git a/examples/advanced/job_config/hello-pt/hello_pt_job.py b/examples/advanced/job_config/hello-pt/hello_pt_job.py new file mode 100644 index 0000000000..d95a962a3b --- /dev/null +++ b/examples/advanced/job_config/hello-pt/hello_pt_job.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from add_shareable_parameter import AddShareable +from cifar10trainer import Cifar10Trainer +from cifar10validator import Cifar10Validator +from print_shareable_parameter import PrintShareable +from pt_model_locator import PTModelLocator + +from nvflare.apis.dxo import DataKind +from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator +from nvflare.app_common.shareablegenerators import FullModelShareableGenerator +from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator +from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval +from nvflare.app_common.workflows.initialize_global_weights import InitializeGlobalWeights +from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather +from nvflare.app_opt.pt import PTFileModelPersistor +from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig +from nvflare.job_config.fed_job_config import FedJobConfig + + +class HelloPTJob: + def __init__(self) -> None: + super().__init__() + self.job = self.define_job() + + def define_job(self) -> FedJobConfig: + # job = FedJobConfig(job_name="hello-pt", min_clients=2, mandatory_clients="site-1") + job: FedJobConfig = FedJobConfig(job_name="hello-pt", min_clients=2) + + server_app = self._create_server_app() + client_app = self._create_client_app() + + app = FedAppConfig(server_app=server_app, client_app=client_app) + job.add_fed_app("app", app) + + # app = FedAppConfig(client_app=client_app) + # job.add_fed_app("client_app", app) + # job.set_site_app("server", "app") + # job.set_site_app("site-1", "app") + # job.set_site_app("site-2", "client_app") + # job.add_resource_spec("site-1", {"memory": "8GB"}) + + job.set_site_app("@ALL", "app") + + return job + + def _create_client_app(self): + client_app = ClientAppConfig() + executor = Cifar10Trainer(lr=0.01, epochs=1) + client_app.add_executor(["train", "submit_model", "get_weights"], executor) + validator = Cifar10Validator() + client_app.add_executor(["validate"], validator) + + task_filter = AddShareable() + client_app.add_task_result_filter(["train"], task_filter) + task_filter = PrintShareable() + client_app.add_task_data_filter(["validate", "train"], task_filter) + return client_app + + def _create_server_app(self): + server_app = ServerAppConfig() + controller = InitializeGlobalWeights(task_name="get_weights") + server_app.add_workflow("pre_train", controller) + controller = ScatterAndGather( + min_clients=2, + num_rounds=2, + start_round=0, + wait_time_after_min_received=10, + aggregator_id="aggregator", + persistor_id="persistor", + shareable_generator_id="shareable_generator", + train_task_name="train", + train_timeout=0, + ) + server_app.add_workflow("scatter_and_gather", controller) + controller = CrossSiteModelEval(model_locator_id="model_locator") + server_app.add_workflow("cross_site_validate", controller) + + component = PTFileModelPersistor() + server_app.add_component("persistor", component) + component = FullModelShareableGenerator() + server_app.add_component("shareable_generator", component) + component = InTimeAccumulateWeightedAggregator( + expected_data_kind=DataKind.WEIGHTS, aggregation_weights={"site-1": 1.0, "site-2": 1.0} + ) + server_app.add_component("aggregator", component) + component = PTModelLocator() + server_app.add_component("model_locator", component) + component = ValidationJsonGenerator() + server_app.add_component("json_generator", component) + + task_filter = AddShareable() + 
server_app.add_task_data_filter(["train"], task_filter) + task_filter = PrintShareable() + server_app.add_task_result_filter(["validate", "train"], task_filter) + return server_app + + def export_job(self, job_root): + self.job.generate_job_config(job_root) + + def simulator_run(self, job_root, workspace): + self.job.simulator_run(job_root, workspace, threads=2) + + +if __name__ == "__main__": + job = HelloPTJob() + + # job.export_job("/tmp/nvflare/jobs") + job.simulator_run("/tmp/nvflare/jobs", "/tmp/nvflare/simulator_workspace") diff --git a/examples/advanced/job_config/hello-pt/print_shareable_parameter.py b/examples/advanced/job_config/hello-pt/print_shareable_parameter.py new file mode 100644 index 0000000000..35d68317ba --- /dev/null +++ b/examples/advanced/job_config/hello-pt/print_shareable_parameter.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.dxo import from_shareable +from nvflare.apis.filter import Filter +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable + + +class PrintShareable(Filter): + def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable: + dxo = from_shareable(shareable) + model_weights = dxo.data + + count = 0 + keys = "" + for item in model_weights.keys(): + keys += item + "; " + count += 1 + print(f"{fl_ctx.get_identity_name()} ----- Total parameters in the Shareable: {count}") + + return shareable diff --git a/examples/advanced/job_config/hello-pt/pt_constants.py b/examples/advanced/job_config/hello-pt/pt_constants.py new file mode 100644 index 0000000000..d8deca3517 --- /dev/null +++ b/examples/advanced/job_config/hello-pt/pt_constants.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class PTConstants: + PTServerName = "server" + PTFileModelName = "FL_global_model.pt" + PTLocalModelName = "local_model.pt" + + PTModelsDir = "models" diff --git a/examples/advanced/job_config/hello-pt/pt_model_locator.py b/examples/advanced/job_config/hello-pt/pt_model_locator.py new file mode 100644 index 0000000000..8681b2b529 --- /dev/null +++ b/examples/advanced/job_config/hello-pt/pt_model_locator.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Union + +import torch.cuda +from pt_constants import PTConstants +from simple_network import SimpleNetwork + +from nvflare.apis.dxo import DXO +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.abstract.model import model_learnable_to_dxo +from nvflare.app_common.abstract.model_locator import ModelLocator +from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager + + +class PTModelLocator(ModelLocator): + def __init__(self): + super().__init__() + self.model = SimpleNetwork() + + def get_model_names(self, fl_ctx: FLContext) -> List[str]: + return [PTConstants.PTServerName] + + def locate_model(self, model_name, fl_ctx: FLContext) -> Union[DXO, None]: + if model_name == PTConstants.PTServerName: + try: + server_run_dir = fl_ctx.get_engine().get_workspace().get_app_dir(fl_ctx.get_job_id()) + model_path = os.path.join(server_run_dir, PTConstants.PTFileModelName) + if not os.path.exists(model_path): + return None + + # Load the torch model + device = "cuda" if torch.cuda.is_available() else "cpu" + data = torch.load(model_path, map_location=device) + + # Set up the persistence manager. + if self.model: + default_train_conf = {"train": {"model": type(self.model).__name__}} + else: + default_train_conf = None + + # Use persistence manager to get learnable + persistence_manager = PTModelPersistenceFormatManager(data, default_train_conf=default_train_conf) + ml = persistence_manager.to_model_learnable(exclude_vars=None) + + # Create dxo and return + return model_learnable_to_dxo(ml) + except Exception as e: + self.log_error(fl_ctx, f"Error in retrieving {model_name}: {e}.", fire_event=False) + return None + else: + self.log_error(fl_ctx, f"PTModelLocator doesn't recognize name: {model_name}", fire_event=False) + return None diff --git a/examples/advanced/job_config/hello-pt/simple_network.py b/examples/advanced/job_config/hello-pt/simple_network.py new file mode 100644 index 0000000000..0f2d2bbe08 --- /dev/null +++ b/examples/advanced/job_config/hello-pt/simple_network.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SimpleNetwork(nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/nvflare/app_common/aggregators/intime_accumulate_model_aggregator.py b/nvflare/app_common/aggregators/intime_accumulate_model_aggregator.py index e74b342231..ff628334ec 100644 --- a/nvflare/app_common/aggregators/intime_accumulate_model_aggregator.py +++ b/nvflare/app_common/aggregators/intime_accumulate_model_aggregator.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Union from nvflare.apis.dxo import DXO, DataKind, from_shareable +from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import ReservedKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable @@ -83,6 +84,18 @@ def __init__( self._single_dxo_key = "" self._weigh_by_local_iter = weigh_by_local_iter + self.aggregation_weights = aggregation_weights + self.exclude_vars = exclude_vars + self.expected_data_kind = expected_data_kind + + def handle_event(self, event_type: str, fl_ctx: FLContext): + # _initialize() can not be called from the constructor. Because it changes the data, even the data format + # of the aggregation_weights and exclude_vars parameters. Inspect could not figure out the passed in + # parameters when re-construct the object creation configuration. + if event_type == EventType.START_RUN: + self._initialize(self.aggregation_weights, self.exclude_vars, self.expected_data_kind) + + def _initialize(self, aggregation_weights, exclude_vars, expected_data_kind): # Check expected data kind if isinstance(expected_data_kind, dict): for k, v in expected_data_kind.items(): @@ -97,7 +110,6 @@ def __init__( f"expected_data_kind = {expected_data_kind} is not {DataKind.WEIGHT_DIFF} or {DataKind.WEIGHTS} or {DataKind.METRICS}" ) self.expected_data_kind = {self._single_dxo_key: expected_data_kind} - # Check exclude_vars if exclude_vars: if not isinstance(exclude_vars, dict) and not isinstance(exclude_vars, str): @@ -111,7 +123,6 @@ def __init__( "A dict exclude_vars should specify exclude_vars for every key in expected_data_kind. " f"But missed these keys: {missing_keys}" ) - exclude_vars_dict = dict() for k in self.expected_data_kind.keys(): if isinstance(exclude_vars, dict): @@ -127,7 +138,6 @@ def __init__( if self._single_dxo_key in self.expected_data_kind: exclude_vars_dict[self._single_dxo_key] = exclude_vars self.exclude_vars = exclude_vars_dict - # Check aggregation weights if _is_nested_aggregation_weights(aggregation_weights): missing_keys = _get_missing_keys(expected_data_kind, aggregation_weights) @@ -136,7 +146,6 @@ def __init__( "A dict of dict aggregation_weights should specify aggregation_weights " f"for every key in expected_data_kind. But missed these keys: {missing_keys}" ) - aggregation_weights = aggregation_weights or {} aggregation_weights_dict = dict() for k in self.expected_data_kind.keys(): @@ -146,7 +155,6 @@ def __init__( # assume same aggregation weights for each entry of DXO collection. 
aggregation_weights_dict[k] = aggregation_weights self.aggregation_weights = aggregation_weights_dict - # Set up DXO aggregators self.dxo_aggregators = dict() for k in self.expected_data_kind.keys(): diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index 0d70a61a1f..eb91decab1 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -70,7 +70,7 @@ def __init__( raise ValueError(f"invalid task_script_path '{task_script_path}'") # only support main() for backward compatibility - self._task_fn_path = task_script_path.replace(".py", ".main") + self._task_script_path = task_script_path self._task_script_args = task_script_args self._task_wait_time = task_wait_time @@ -85,9 +85,6 @@ def __init__( self._to_nvflare_converter_id = to_nvflare_converter_id self._to_nvflare_converter: Optional[ParamsConverter] = None - self._task_fn_wrapper = ExecTaskFuncWrapper( - task_fn_path=self._task_fn_path, task_main_args=self._task_script_args - ) self._engine = None self._task_fn_thread = None self._log_thread = None @@ -97,6 +94,8 @@ def __init__( self._data_bus.subscribe([TOPIC_LOG_DATA], self.log_result_callback) self.local_result = None self._fl_ctx = None + self._task_fn_path = None + self._task_fn_wrapper = None def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.START_RUN: @@ -105,6 +104,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): self._fl_ctx = fl_ctx self._init_converter(fl_ctx) + self._task_fn_path = self._task_script_path.replace(".py", ".main") + self._task_fn_wrapper = ExecTaskFuncWrapper( + task_fn_path=self._task_fn_path, task_main_args=self._task_script_args + ) + self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run) self._task_fn_thread.start() diff --git a/nvflare/app_opt/pt/in_process_client_api_executor.py b/nvflare/app_opt/pt/in_process_client_api_executor.py index 99c0e52ef2..b31702e363 100644 --- a/nvflare/app_opt/pt/in_process_client_api_executor.py +++ b/nvflare/app_opt/pt/in_process_client_api_executor.py @@ -36,6 +36,7 @@ def __init__( train_task_name: str = "train", evaluate_task_name: str = "evaluate", submit_model_task_name: str = "submit_model", + params_exchange_format=ExchangeFormat.PYTORCH, ): super(PTInProcessClientAPIExecutor, self).__init__( task_script_path=task_script_path, @@ -48,7 +49,7 @@ def __init__( submit_model_task_name=submit_model_task_name, from_nvflare_converter_id=from_nvflare_converter_id, to_nvflare_converter_id=to_nvflare_converter_id, - params_exchange_format=ExchangeFormat.PYTORCH, + params_exchange_format=params_exchange_format, params_transfer_type=params_transfer_type, log_pull_interval=log_pull_interval, ) diff --git a/nvflare/job_config/__init__.py b/nvflare/job_config/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/job_config/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/job_config/base_app_config.py b/nvflare/job_config/base_app_config.py new file mode 100644 index 0000000000..e5c021d987 --- /dev/null +++ b/nvflare/job_config/base_app_config.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +from abc import ABC +from typing import Dict, List + +from nvflare.apis.filter import Filter +from nvflare.apis.fl_component import FLComponent + + +class BaseAppConfig(ABC): + """BaseAppConfig holds the base essential component data for the ServerApp and ClientApp, including the + task_data_filters, task_result_filters, system components and used external scripts. + + """ + + def __init__(self) -> None: + super().__init__() + + self.task_data_filters: [(List[str], Filter)] = [] + self.task_result_filters: [(List[str], Filter)] = [] + self.components: Dict[str, object] = {} + self.ext_scripts = [] + + self.handlers: [FLComponent] = [] + + def add_component(self, cid: str, component): + if cid in self.components.keys(): + raise RuntimeError(f"Component with ID:{cid} already exist.") + + self.components[cid] = component + + if isinstance(component, FLComponent): + self.handlers.append(component) + + def add_task_data_filter(self, tasks: List[str], filter: Filter): + self._add_task_filter(tasks, filter, self.task_data_filters) + + def add_task_result_filter(self, tasks: List[str], filter: Filter): + self._add_task_filter(tasks, filter, self.task_result_filters) + + def add_ext_script(self, ext_script: str): + if not isinstance(ext_script, str): + raise RuntimeError(f"ext_script must be type of str, but got {ext_script.__class__}") + + if not os.path.exists(ext_script): + raise RuntimeError(f"Could not locate external script: {ext_script}") + + if not ext_script.endswith(".py"): + raise RuntimeError(f"External script: {ext_script} must be a '.py' file.") + + self.ext_scripts.append(ext_script) + + def _add_task_filter(self, tasks, filter, filters): + if not isinstance(filter, Filter): + raise RuntimeError(f"filter must be type of Filter, but got {filter.__class__}") + for task in tasks: + for fd in filters: + if task in fd.tasks: + raise RuntimeError(f"Task {task} already defined in the task filters.") + filters.append((tasks, filter)) diff --git a/nvflare/job_config/fed_app_config.py b/nvflare/job_config/fed_app_config.py new file mode 100644 index 0000000000..3fe76159f1 --- /dev/null +++ b/nvflare/job_config/fed_app_config.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +from nvflare.apis.executor import Executor +from nvflare.apis.impl.controller import Controller +from nvflare.apis.impl.wf_comm_server import WFCommServer +from nvflare.job_config.base_app_config import BaseAppConfig +from nvflare.private.fed.client.client_json_config import _ExecutorDef +from nvflare.private.fed.server.server_json_config import WorkFlow + + +class ClientAppConfig(BaseAppConfig): + """ClientAppConfig represents the ClientApp inside the Job. It holds the BaseAppConfig components and the task + executors components data for the ClientApp. + + """ + + def __init__(self) -> None: + super().__init__() + + self.executors: [_ExecutorDef] = [] + + def add_executor(self, tasks: List[str], executor: Executor): + if not isinstance(executor, Executor): + raise RuntimeError(f"workflow must be type of Executor, but got {executor.__class__}") + + e = _ExecutorDef() + e.tasks = tasks + e.executor = executor + self.executors.append(e) + + +class ServerAppConfig(BaseAppConfig): + """ServerAppConfig represents the ServerApp inside the Job. it holds the BaseAppConfig components and the + workflow components data for the ServerApp. + + """ + + def __init__(self) -> None: + super().__init__() + + self.workflows: [Controller] = [] + self.ids = [] + + def add_workflow(self, cid, controller: Controller): + if not isinstance(controller, Controller): + raise RuntimeError(f"workflow must be type of Controller, but got {controller.__class__}") + + # self.add_component(cid, controller) + if cid in self.components.keys() or cid in self.ids: + raise RuntimeError(f"Component with ID:{cid} already exist.") + + communicator = WFCommServer() + self.handlers.append(communicator) + controller.set_communicator(communicator) + + self.workflows.append(WorkFlow(cid, controller)) + self.ids.append(cid) + + +class FedAppConfig: + """FedAppConfig represents the App information inside the Job. It contains either a ServerApp, or a ClientApp, or + both of them. + + """ + + def __init__(self, server_app: ServerAppConfig = None, client_app: ClientAppConfig = None) -> None: + super().__init__() + + if server_app and not isinstance(server_app, ServerAppConfig): + raise ValueError(f"server_app must be type of ServerAppConfig, but got {server_app.__class__}") + if client_app and not isinstance(client_app, ClientAppConfig): + raise ValueError(f"client_app must be type of ClientAppConfig, but got {client_app.__class__}") + + self.server_app: ServerAppConfig = server_app + self.client_app: ClientAppConfig = client_app diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py new file mode 100644 index 0000000000..793e40e0c2 --- /dev/null +++ b/nvflare/job_config/fed_job_config.py @@ -0,0 +1,321 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import builtins +import inspect +import json +import os +import shutil +from enum import Enum +from typing import Dict + +from nvflare import SimulatorRunner +from nvflare.job_config.fed_app_config import FedAppConfig +from nvflare.private.fed.app.fl_conf import FL_PACKAGES + +CONFIG = "config" +CUSTOM = "custom" +FED_SERVER_JSON = "config_fed_server.json" +FED_CLIENT_JSON = "config_fed_client.json" +META_JSON = "meta.json" + + +class FedJobConfig: + """FedJobConfig represents the job in the NVFlare.""" + + def __init__(self, job_name, min_clients, mandatory_clients=None) -> None: + """FedJobConfig uses the job_name, min_clients and optional mandatory_clients to create the object. + It also provides the method to add in the FedApp, the deployment map of the FedApp and participants, + and the resource _spec requirements of the participants if needed. + + Args: + job_name: the name of the NVFlare job + min_clients: the minimum number of clients for the job + mandatory_clients: mandatory clients to run the job (optional) + """ + super().__init__() + + self.job_name = job_name + self.min_clients = min_clients + self.mandatory_clients = mandatory_clients + + self.fed_apps: Dict[str, FedAppConfig] = {} + self.deploy_map: Dict[str, str] = {} + self.resource_specs: Dict[str, Dict] = {} + + self.custom_modules = [] + + def add_fed_app(self, app_name: str, fed_app: FedAppConfig): + if not isinstance(fed_app, FedAppConfig): + raise RuntimeError(f"server_app must be type of FedAppConfig, but got {fed_app.__class__}") + + self.fed_apps[app_name] = fed_app + + def set_site_app(self, site_name: str, app_name: str): + if app_name not in self.fed_apps.keys(): + raise RuntimeError(f"fed_app {app_name} does not exist.") + + self.deploy_map[site_name] = app_name + + def add_resource_spec(self, site_name: str, resource_spec: Dict): + if site_name in self.resource_specs.keys(): + raise RuntimeError(f"{site_name} resource specs already exist.") + if not isinstance(resource_spec, dict): + raise RuntimeError(f"resource_spec must be a dict. 
But got: {resource_spec.__class__}") + + self.resource_specs[site_name] = resource_spec + + def _generate_meta(self, job_dir): + """generate the job meta.json + + Returns: + + """ + meta_file = os.path.join(job_dir, META_JSON) + meta_json = { + "name": self.job_name, + "resource_spec": self.resource_specs, + "min_clients": self.min_clients, + "deploy_map": self._get_deploy_map(), + } + if self.mandatory_clients: + meta_json["mandatory_clients"] = self.mandatory_clients + + with open(meta_file, "w") as outfile: + json_dump = json.dumps(meta_json, indent=4) + outfile.write(json_dump) + + def generate_job_config(self, job_root): + """generate the job config + + Returns: + + """ + job_dir = os.path.join(job_root, self.job_name) + if os.path.exists(job_dir): + shutil.rmtree(job_dir, ignore_errors=True) + + for app_name, fed_app in self.fed_apps.items(): + self.custom_modules = [] + config_dir = os.path.join(job_dir, app_name, CONFIG) + custom_dir = os.path.join(job_dir, app_name, CUSTOM) + os.makedirs(config_dir, exist_ok=True) + + if fed_app.server_app: + self._get_server_app(config_dir, custom_dir, fed_app) + + if fed_app.client_app: + self._get_client_app(config_dir, custom_dir, fed_app) + + self._generate_meta(job_dir) + + def simulator_run(self, job_root, workspace, clients=None, n_clients=None, threads=None, gpu=None): + self.generate_job_config(job_root) + + simulator = SimulatorRunner( + job_folder=os.path.join(job_root, self.job_name), + workspace=workspace, + clients=clients, + n_clients=n_clients, + threads=threads, + gpu=gpu, + ) + simulator.run() + + def _get_server_app(self, config_dir, custom_dir, fed_app): + server_app = {"format_version": 2, "workflows": []} + for workflow in fed_app.server_app.workflows: + server_app["workflows"].append( + { + "id": workflow.id, + "path": self._get_class_path(workflow.controller, custom_dir), + "args": self._get_args(workflow.controller, custom_dir), + } + ) + self._get_base_app(custom_dir, fed_app.server_app, server_app) + server_config = os.path.join(config_dir, FED_SERVER_JSON) + with open(server_config, "w") as outfile: + json_dump = json.dumps(server_app, indent=4) + outfile.write(json_dump) + + self._copy_ext_scripts(custom_dir, fed_app.server_app.ext_scripts) + + def _copy_ext_scripts(self, custom_dir, ext_scripts): + for script in ext_scripts: + dest_file = os.path.join(custom_dir, script) + module = "".join(script.rsplit(".py", 1)).replace(os.sep, ".") + self._copy_source_file(custom_dir, module, script, dest_file) + + def _get_class_path(self, obj, custom_dir): + module = obj.__module__ + source_file = inspect.getsourcefile(obj.__class__) + self._get_custom_file(custom_dir, module, source_file) + + return obj.__module__ + "." 
+ obj.__class__.__name__ + + def _get_custom_file(self, custom_dir, module, source_file): + package = module.split(".")[0] + if os.path.exists(source_file): + if package not in FL_PACKAGES and module not in self.custom_modules: + module_path = module.replace(".", os.sep) + if module_path in source_file: + index = source_file.index(module_path) + dest = source_file[index:] + + self.custom_modules.append(module) + os.makedirs(custom_dir, exist_ok=True) + # dest_file = os.path.join(custom_dir, module.replace(".", os.sep) + ".py") + dest_file = os.path.join(custom_dir, dest) + + self._copy_source_file(custom_dir, module, source_file, dest_file) + + def _copy_source_file(self, custom_dir, module, source_file, dest_file): + os.makedirs(custom_dir, exist_ok=True) + source_dir = os.path.dirname(source_file) + with open(source_file, "r") as sf: + import_lines = list(self.locate_imports(sf, dest_file)) + for line in import_lines: + import_module = line.split(" ")[1] + + import_source = import_module + if import_module.startswith("."): + import_source = import_source[1:] + new_module = module.split(".")[0:-1] + new_module.append(import_source) + import_module = ".".join(new_module) + + import_source_file = os.path.join(source_dir, import_source.replace(".", os.sep) + ".py") + if os.path.exists(import_source_file): + self._get_custom_file(custom_dir, import_module, import_source_file) + + def _get_client_app(self, config_dir, custom_dir, fed_app): + client_app = {"format_version": 2, "executors": []} + for e in fed_app.client_app.executors: + client_app["executors"].append( + { + "tasks": e.tasks, + "executor": { + "path": self._get_class_path(e.executor, custom_dir), + "args": self._get_args(e.executor, custom_dir), + }, + } + ) + self._get_base_app(custom_dir, fed_app.client_app, client_app) + client_config = os.path.join(config_dir, FED_CLIENT_JSON) + with open(client_config, "w") as outfile: + json_dump = json.dumps(client_app, indent=4) + outfile.write(json_dump) + + self._copy_ext_scripts(custom_dir, fed_app.client_app.ext_scripts) + + def _get_base_app(self, custom_dir, app, app_config): + app_config["components"] = [] + for cid, component in app.components.items(): + app_config["components"].append( + { + "id": cid, + "path": self._get_class_path(component, custom_dir), + "args": self._get_args(component, custom_dir), + } + ) + app_config["task_data_filters"] = [] + for tasks, filter in app.task_data_filters: + app_config["task_data_filters"].append( + { + "tasks": tasks, + "filters": [ + { + # self._get_filters(task_filter.filter, custom_dir) + "path": self._get_class_path(filter, custom_dir), + "args": self._get_args(filter, custom_dir), + } + ], + } + ) + app_config["task_result_filters"] = [] + for tasks, filter in app.task_result_filters: + app_config["task_result_filters"].append( + { + "tasks": tasks, + "filters": [ + { + # self._get_filters(result_filer.filter, custom_dir) + "path": self._get_class_path(filter, custom_dir), + "args": self._get_args(filter, custom_dir), + } + ], + } + ) + + def _get_args(self, component, custom_dir): + parameters = self._get_init_parameters(component) + attrs = component.__dict__ + args = {} + + for param in parameters: + attr_key = param if param in attrs.keys() else "_" + param + + if attr_key in ["args", "kwargs"]: + continue + + if attr_key in attrs.keys() and parameters[param].default != attrs[attr_key]: + if type(attrs[attr_key]).__name__ in dir(builtins): + args[param] = attrs[attr_key] + elif issubclass(attrs[attr_key].__class__, Enum): + 
args[param] = attrs[attr_key].value + else: + args[param] = { + "path": self._get_class_path(attrs[attr_key], custom_dir), + "args": self._get_args(attrs[attr_key], custom_dir), + } + + return args + + def _get_init_parameters(self, component): + class__ = component.__class__ + parameters = {} + self._retrieve_parameters(class__, parameters) + return parameters + + def _retrieve_parameters(self, class__, parameters): + constructor = class__.__init__ + constructor__parameters = inspect.signature(constructor).parameters + parameters.update(constructor__parameters) + if "args" in constructor__parameters.keys() and "kwargs" in constructor__parameters.keys(): + for item in class__.__bases__: + parameters.update(self._retrieve_parameters(item, parameters)) + return parameters + + def _get_filters(self, filters, custom_dir): + r = [] + for f in filters: + r.append({"path": self._get_class_path(f, custom_dir), "args": self._get_args(f, custom_dir)}) + return r + + def locate_imports(self, sf, dest_file): + os.makedirs(os.path.dirname(dest_file), exist_ok=True) + with open(dest_file, "w") as df: + for line in sf: + df.write(line) + trimmed = line.strip() + if trimmed.startswith("from ") and ("import " in trimmed): + yield trimmed + elif trimmed.startswith("import "): + yield trimmed + + def _get_deploy_map(self): + deploy_map = {} + for site, app_name in self.deploy_map.items(): + deploy_map[app_name] = deploy_map.get(app_name, []) + deploy_map[app_name].append(site) + return deploy_map diff --git a/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py b/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py index e96c81ea5f..3f5ae02e4c 100644 --- a/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py +++ b/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py @@ -72,11 +72,14 @@ class TestInTimeAccumulateWeightedAggregator: ) def test_invalid_create(self, exclude_vars, aggregation_weights, expected_data_kind, error, error_msg): with pytest.raises(error, match=re.escape(error_msg)): - _ = InTimeAccumulateWeightedAggregator( + aggregator = InTimeAccumulateWeightedAggregator( exclude_vars=exclude_vars, aggregation_weights=aggregation_weights, expected_data_kind=expected_data_kind, ) + aggregator._initialize( + aggregator.aggregation_weights, aggregator.exclude_vars, aggregator.expected_data_kind + ) @pytest.mark.parametrize( "exclude_vars,aggregation_weights,expected_data_kind,expected_object", @@ -115,9 +118,13 @@ def test_invalid_create(self, exclude_vars, aggregation_weights, expected_data_k ], ) def test_create(self, exclude_vars, aggregation_weights, expected_data_kind, expected_object): + expected_object._initialize( + expected_object.aggregation_weights, expected_object.exclude_vars, expected_object.expected_data_kind + ) result = InTimeAccumulateWeightedAggregator( exclude_vars=exclude_vars, aggregation_weights=aggregation_weights, expected_data_kind=expected_data_kind ) + result._initialize(result.aggregation_weights, result.exclude_vars, result.expected_data_kind) assert result.exclude_vars == expected_object.exclude_vars assert result.aggregation_weights == expected_object.aggregation_weights assert result.expected_data_kind == expected_object.expected_data_kind @@ -126,6 +133,7 @@ def test_create(self, exclude_vars, aggregation_weights, expected_data_kind, exp def test_accept(self, current_round, contribution_round, expected): aggregation_weights = 
{f"client_{i}": random.random() for i in range(2)} agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) client_name = "client_0" iter_number = 1 weights = np.random.random(4) @@ -192,6 +200,7 @@ def test_accept(self, current_round, contribution_round, expected): def test_aggregate(self, received, expected): aggregation_weights = {k: v["weight"] for k, v in received.items()} agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) fl_ctx = FLContext() fl_ctx.set_prop(AppConstants.CURRENT_ROUND, 0) for k, v in received.items(): @@ -216,6 +225,7 @@ def test_aggregate(self, received, expected): def test_aggregate_random(self, shape, n_clients): aggregation_weights = {f"client_{i}": random.random() for i in range(n_clients)} agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) weighted_sum = np.zeros(shape) sum_of_weights = 0 fl_ctx = FLContext() @@ -254,6 +264,7 @@ def test_aggregate_random_dxos(self, num_dxo, shape, n_clients): aggregation_weights=aggregation_weights, expected_data_kind={dxo_name: DataKind.WEIGHT_DIFF for dxo_name in dxo_names}, ) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) weighted_sum = {dxo_name: np.zeros(shape) for dxo_name in dxo_names} sum_of_weights = {dxo_name: 0 for dxo_name in dxo_names} fl_ctx = FLContext() From d4afbee5faa0d62ab76f06f3890f8ddf9a34ef79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Thu, 18 Apr 2024 17:22:12 -0700 Subject: [PATCH 11/14] Enhancements from 2.4 (#2519) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Starts heartbeat after task is pull and before task execution (#2415) * Starts pipe handler heartbeat send/check after task is pull before task execution (#2442) * [2.4] Improve cell pipe timeout handling (#2441) * improve cell pipe timeout handling * improved end and abort handling * improve timeout handling --------- Co-authored-by: Yuan-Ting Hsieh (謝沅廷) * [2.4] Enhance launcher executor (#2433) * Update LauncherExecutor logs and execution setup timeout * Change name * [2.4] Fire and forget for pipe handler control messages (#2413) * Fire and forget for pipe handler control messages * Add default timeout value * fix wait-for-reply (#2478) * Fix pipe handler timeout in task exchanger and launcher executor (#2495) * Fix metric relay pipe handler timeout (#2496) * Rely on launcher check_run_status to pause/resume hb (#2502) Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> --------- Co-authored-by: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> --- .../executors/client_api_launcher_executor.py | 33 ++--- .../app_common/executors/launcher_executor.py | 67 ++++++----- .../app_common/executors/task_exchanger.py | 17 +-- nvflare/app_common/widgets/metric_relay.py | 4 +- nvflare/fuel/f3/cellnet/cell.py | 113 ++++++++++++------ nvflare/fuel/utils/pipe/cell_pipe.py | 82 +++++++++---- nvflare/fuel/utils/pipe/file_pipe.py | 6 +- nvflare/fuel/utils/pipe/pipe.py | 8 ++ nvflare/fuel/utils/pipe/pipe_handler.py | 20 +++- 9 files changed, 232 insertions(+), 118 deletions(-) 
diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index f15dad07b8..7cb631b4d5 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -31,12 +31,12 @@ def __init__( launch_timeout: Optional[float] = None, task_wait_timeout: Optional[float] = None, last_result_transfer_timeout: float = 300.0, - external_execution_wait: float = 5.0, - peer_read_timeout: Optional[float] = None, + external_pre_init_timeout: float = 60.0, + peer_read_timeout: Optional[float] = 60.0, monitor_interval: float = 0.01, read_interval: float = 0.5, heartbeat_interval: float = 5.0, - heartbeat_timeout: float = 30.0, + heartbeat_timeout: float = 60.0, workers: int = 4, train_with_evaluation: bool = True, train_task_name: str = "train", @@ -51,22 +51,23 @@ def __init__( """Initializes the ClientAPILauncherExecutor. Args: - pipe_id (Optional[str]): Identifier for obtaining the Pipe from NVFlare components. + pipe_id (str): Identifier for obtaining the Pipe from NVFlare components. launcher_id (Optional[str]): Identifier for obtaining the Launcher from NVFlare components. launch_timeout (Optional[float]): Timeout for the Launcher's "launch_task" method to complete (None for no timeout). task_wait_timeout (Optional[float]): Timeout for retrieving the task result (None for no timeout). - last_result_transfer_timeout (float): Timeout for transmitting the last result from an external process (default: 5.0). + last_result_transfer_timeout (float): Timeout for transmitting the last result from an external process. This value should be greater than the time needed for sending the whole result. - peer_read_timeout (Optional[float]): Timeout for waiting the task to be read by the peer from the pipe (None for no timeout). - monitor_interval (float): Interval for monitoring the launcher (default: 0.01). - read_interval (float): Interval for reading from the pipe (default: 0.5). - heartbeat_interval (float): Interval for sending heartbeat to the peer (default: 5.0). - heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer (default: 30.0). - workers (int): Number of worker threads needed (default: 4). - train_with_evaluation (bool): Whether to run training with global model evaluation (default: True). - train_task_name (str): Task name of train mode (default: train). - evaluate_task_name (str): Task name of evaluate mode (default: evaluate). - submit_model_task_name (str): Task name of submit_model mode (default: submit_model). + external_pre_init_timeout (float): Time to wait for external process before it calls flare.init(). + peer_read_timeout (float, optional): time to wait for peer to accept sent message. + monitor_interval (float): Interval for monitoring the launcher. + read_interval (float): Interval for reading from the pipe. + heartbeat_interval (float): Interval for sending heartbeat to the peer. + heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer. + workers (int): Number of worker threads needed. + train_with_evaluation (bool): Whether to run training with global model evaluation. + train_task_name (str): Task name of train mode. + evaluate_task_name (str): Task name of evaluate mode. + submit_model_task_name (str): Task name of submit_model mode. from_nvflare_converter_id (Optional[str]): Identifier used to get the ParamsConverter from NVFlare components. 
This ParamsConverter will be called when model is sent from nvflare controller side to executor side. to_nvflare_converter_id (Optional[str]): Identifier used to get the ParamsConverter from NVFlare components. @@ -83,7 +84,7 @@ def __init__( launch_timeout=launch_timeout, task_wait_timeout=task_wait_timeout, last_result_transfer_timeout=last_result_transfer_timeout, - external_execution_wait=external_execution_wait, + external_pre_init_timeout=external_pre_init_timeout, peer_read_timeout=peer_read_timeout, monitor_interval=monitor_interval, read_interval=read_interval, diff --git a/nvflare/app_common/executors/launcher_executor.py b/nvflare/app_common/executors/launcher_executor.py index d4dfe964df..db181d5279 100644 --- a/nvflare/app_common/executors/launcher_executor.py +++ b/nvflare/app_common/executors/launcher_executor.py @@ -42,13 +42,13 @@ def __init__( launch_timeout: Optional[float] = None, task_wait_timeout: Optional[float] = None, last_result_transfer_timeout: float = 300.0, - external_execution_wait: float = 5.0, - peer_read_timeout: Optional[float] = None, - monitor_interval: float = 1.0, + external_pre_init_timeout: float = 60.0, + peer_read_timeout: Optional[float] = 60.0, + monitor_interval: float = 0.1, read_interval: float = 0.5, heartbeat_interval: float = 5.0, - heartbeat_timeout: float = 30.0, - workers: int = 1, + heartbeat_timeout: float = 60.0, + workers: int = 4, train_with_evaluation: bool = True, train_task_name: str = "train", evaluate_task_name: str = "evaluate", @@ -63,18 +63,19 @@ def __init__( launcher_id (Optional[str]): Identifier for obtaining the Launcher from NVFlare components. launch_timeout (Optional[float]): Timeout for the Launcher's "launch_task" method to complete (None for no timeout). task_wait_timeout (Optional[float]): Timeout for retrieving the task result (None for no timeout). - last_result_transfer_timeout (float): Timeout for transmitting the last result from an external process (default: 5.0). + last_result_transfer_timeout (float): Timeout for transmitting the last result from an external process. This value should be greater than the time needed for sending the whole result. - peer_read_timeout (Optional[float]): Timeout for waiting the task to be read by the peer from the pipe (None for no timeout). - monitor_interval (float): Interval for monitoring the launcher (default: 0.01). - read_interval (float): Interval for reading from the pipe (default: 0.5). - heartbeat_interval (float): Interval for sending heartbeat to the peer (default: 5.0). - heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer (default: 30.0). - workers (int): Number of worker threads needed (default: 1). - train_with_evaluation (bool): Whether to run training with global model evaluation (default: True). - train_task_name (str): Task name of train mode (default: train). - evaluate_task_name (str): Task name of evaluate mode (default: evaluate). - submit_model_task_name (str): Task name of submit_model mode (default: submit_model). + external_pre_init_timeout (float): Time to wait for external process before it calls flare.init(). + peer_read_timeout (float, optional): time to wait for peer to accept sent message. + monitor_interval (float): Interval for monitoring the launcher. + read_interval (float): Interval for reading from the pipe. + heartbeat_interval (float): Interval for sending heartbeat to the peer. + heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer. + workers (int): Number of worker threads needed. 
+ train_with_evaluation (bool): Whether to run training with global model evaluation. + train_task_name (str): Task name of train mode. + evaluate_task_name (str): Task name of evaluate mode. + submit_model_task_name (str): Task name of submit_model mode. from_nvflare_converter_id (Optional[str]): Identifier used to get the ParamsConverter from NVFlare components. This ParamsConverter will be called when model is sent from nvflare controller side to executor side. to_nvflare_converter_id (Optional[str]): Identifier used to get the ParamsConverter from NVFlare components. @@ -96,7 +97,7 @@ def __init__( self._launcher_finish = False self._launcher_finish_time = None self._last_result_transfer_timeout = last_result_transfer_timeout - self._external_execution_wait = external_execution_wait + self._external_pre_init_timeout = external_pre_init_timeout self._received_result = Event() self._job_end = False @@ -248,7 +249,6 @@ def _initialize_external_execution( self.log_error(fl_ctx, "External execution set up failed.") abort_signal.trigger("External execution set up failed.") return False - time.sleep(self._external_execution_wait) return True def _execute_launcher_method_in_thread_executor(self, method_name: str, **kwargs) -> Any: @@ -276,17 +276,25 @@ def _execute_launcher_method_in_thread_executor(self, method_name: str, **kwargs def _wait_external_setup(self, task_name: str, fl_ctx: FLContext, abort_signal: Signal): start_time = time.time() while True: - if self._launch_timeout and time.time() - start_time >= self._launch_timeout: - self.log_error(fl_ctx, f"External execution is not set up within timeout: {self._launch_timeout}") + if self._external_pre_init_timeout and time.time() - start_time >= self._external_pre_init_timeout: + self.log_error( + fl_ctx, + f"External process has not called flare.init within timeout: {self._external_pre_init_timeout}", + ) return False if abort_signal.triggered: + self.log_info(fl_ctx, "External execution has not called flare.init but abort signal is triggered.") return False if self.peer_is_up_or_dead(): return True - if self.launcher.check_run_status(task_name, fl_ctx) != LauncherRunStatus.RUNNING: + run_status = self.launcher.check_run_status(task_name, fl_ctx) + if run_status != LauncherRunStatus.RUNNING: + self.log_info( + fl_ctx, f"External process has not called flare.init and run status becomes {run_status}." 
+ ) return False time.sleep(0.1) @@ -294,18 +302,17 @@ def _wait_external_setup(self, task_name: str, fl_ctx: FLContext, abort_signal: def _finalize_external_execution( self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal ) -> bool: - with self._lock: - if self._job_end: - ask_peer_end_success = self.ask_peer_to_end(fl_ctx) - if not ask_peer_end_success: - return False + if self._job_end: + ask_peer_end_success = self.ask_peer_to_end(fl_ctx) + if not ask_peer_end_success: + return False check_run_status = self._execute_launcher_method_in_thread_executor( method_name="check_run_status", task_name=task_name, fl_ctx=fl_ctx, ) - if check_run_status != LauncherRunStatus.COMPLETE_SUCCESS: + if not self._received_result.is_set() and check_run_status != LauncherRunStatus.COMPLETE_SUCCESS: self.log_warning(fl_ctx, f"Try to stop task ({task_name}) when launcher run status is {check_run_status}") self.log_info(fl_ctx, f"Calling stop task ({task_name}).") @@ -365,9 +372,6 @@ def _monitor_launcher(self, fl_ctx: FLContext): if self.launcher is None: break - if self._current_task is None: - continue - task_name = self._current_task run_status = self._execute_launcher_method_in_thread_executor( method_name="check_run_status", @@ -381,16 +385,19 @@ def _monitor_launcher(self, fl_ctx: FLContext): continue elif run_status == LauncherRunStatus.NOT_RUNNING: + # pause pipe handler because external process is not running self.pause_pipe_handler() continue elif run_status == LauncherRunStatus.RUNNING: + # resume pipe handler when external process is running self.resume_pipe_handler() continue elif ( run_status == LauncherRunStatus.COMPLETE_FAILED or run_status == LauncherRunStatus.COMPLETE_SUCCESS ): + # pause pipe handler because external process is completed self.pause_pipe_handler() if not self._launcher_finish: self._launcher_finish_time = time.time() diff --git a/nvflare/app_common/executors/task_exchanger.py b/nvflare/app_common/executors/task_exchanger.py index e9053629a7..3d4eddf200 100644 --- a/nvflare/app_common/executors/task_exchanger.py +++ b/nvflare/app_common/executors/task_exchanger.py @@ -35,10 +35,10 @@ def __init__( pipe_id: str, read_interval: float = 0.5, heartbeat_interval: float = 5.0, - heartbeat_timeout: Optional[float] = 30.0, + heartbeat_timeout: Optional[float] = 60.0, resend_interval: float = 2.0, max_resends: Optional[int] = None, - peer_read_timeout: Optional[float] = 5.0, + peer_read_timeout: Optional[float] = 60.0, task_wait_time: Optional[float] = None, result_poll_interval: float = 0.5, pipe_channel_name=PipeChannelName.TASK, @@ -48,19 +48,16 @@ def __init__( Args: pipe_id (str): component id of pipe. read_interval (float): how often to read from pipe. - Defaults to 0.5. heartbeat_interval (float): how often to send heartbeat to peer. - Defaults to 5.0. heartbeat_timeout (float, optional): how long to wait for a heartbeat from the peer before treating the peer as dead, - 0 means DO NOT check for heartbeat. Defaults to 30.0. + 0 means DO NOT check for heartbeat. resend_interval (float): how often to resend a message if failing to send. None means no resend. Note that if the pipe does not support resending, - then no resend. Defaults to 2.0. + then no resend. max_resends (int, optional): max number of resend. None means no limit. Defaults to None. peer_read_timeout (float, optional): time to wait for peer to accept sent message. - Defaults to 5.0. task_wait_time (float, optional): how long to wait for a task to complete. None means waiting forever. 
Defaults to None. result_poll_interval (float): how often to poll task result. @@ -145,7 +142,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort task_id = shareable.get_header(key=FLContextKey.TASK_ID) # send to peer - self.log_debug(fl_ctx, "sending task to peer ...") + self.log_info(fl_ctx, f"sending task to peer {self.peer_read_timeout=}") req = Message.new_request(topic=task_name, data=shareable, msg_id=task_id) start_time = time.time() has_been_read = self.pipe_handler.send_to_peer(req, timeout=self.peer_read_timeout, abort_signal=abort_signal) @@ -156,6 +153,8 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort ) return make_reply(ReturnCode.EXECUTION_EXCEPTION) + self.log_info(fl_ctx, f"task {task_name} sent to peer in {time.time()-start_time} secs") + # wait for result self.log_debug(fl_ctx, "Waiting for result from peer") start = time.time() @@ -213,6 +212,8 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort if not self.check_output_shareable(task_name, result, fl_ctx): self.log_error(fl_ctx, "bad task result from peer") return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + self.log_info(fl_ctx, f"received result of {task_name} from peer in {time.time()-start} secs") return result except Exception as ex: self.log_error(fl_ctx, f"Failed to convert result: {secure_format_exception(ex)}") diff --git a/nvflare/app_common/widgets/metric_relay.py b/nvflare/app_common/widgets/metric_relay.py index f896fc6a6b..cf321a59d4 100644 --- a/nvflare/app_common/widgets/metric_relay.py +++ b/nvflare/app_common/widgets/metric_relay.py @@ -33,7 +33,6 @@ def __init__( pipe_id: str, read_interval=0.1, heartbeat_interval=5.0, - heartbeat_timeout=30.0, pipe_channel_name=PipeChannelName.METRIC, event_type: str = ANALYTIC_EVENT_TYPE, fed_event: bool = True, @@ -42,7 +41,6 @@ def __init__( self.pipe_id = pipe_id self._read_interval = read_interval self._heartbeat_interval = heartbeat_interval - self._heartbeat_timeout = heartbeat_timeout self.pipe_channel_name = pipe_channel_name self.pipe = None self.pipe_handler = None @@ -64,7 +62,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): pipe=self.pipe, read_interval=self._read_interval, heartbeat_interval=self._heartbeat_interval, - heartbeat_timeout=self._heartbeat_timeout, + heartbeat_timeout=0, ) self.pipe_handler.set_status_cb(self._pipe_status_cb) self.pipe_handler.set_message_cb(self._pipe_msg_cb) diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index 07f7a72c68..723b894f75 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -27,6 +27,7 @@ from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey from nvflare.fuel.f3.streaming.stream_types import StreamFuture from nvflare.private.defs import CellChannel +from nvflare.security.logging import secure_format_exception CHANNELS_TO_EXCLUDE = ( CellChannel.CLIENT_MAIN, @@ -233,15 +234,21 @@ def _get_result(self, req_id): return waiter.result def _future_wait(self, future, timeout): + # future could have an error! 
last_progress = 0 while not future.waiter.wait(timeout): + if future.error: + return False current_progress = future.get_progress() if last_progress == current_progress: return False else: self.logger.debug(f"{current_progress=}") last_progress = current_progress - return True + if future.error: + return False + else: + return True def _encode_message(self, msg: Message): try: @@ -250,11 +257,43 @@ def _encode_message(self, msg: Message): self.logger.error(f"Can't encode {msg=} {exc=}") raise exc - def _send_request(self, channel, target, topic, request, timeout=10.0, secure=False, optional=False): + def _send_request( + self, + channel, + target, + topic, + request, + timeout=10.0, + secure=False, + optional=False, + ): + """Stream one request to the target + + Args: + channel: message channel name + target: FQCN of the target cell + topic: topic of the message + request: request message + timeout: how long to wait + secure: is P2P security to be applied + optional: is the message optional + + Returns: reply data + + """ self._encode_message(request) return self._send_one_request(channel, target, topic, request, timeout, secure, optional) - def _send_one_request(self, channel, target, topic, request, timeout=10.0, secure=False, optional=False): + def _send_one_request( + self, + channel, + target, + topic, + request, + timeout=10.0, + secure=False, + optional=False, + ): req_id = str(uuid.uuid4()) request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id}) @@ -263,42 +302,46 @@ def _send_one_request(self, channel, target, topic, request, timeout=10.0, secur waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT)) self.requests_dict[req_id] = waiter - future = self.send_blob( - channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional - ) - - self.logger.debug(f"{req_id=}: Waiting starts") - # Three stages, sending, waiting for receiving first byte, receiving - - # sending with progress timeout - self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") - sending_complete = self._future_wait(future, timeout) - if not sending_complete: - self.logger.info(f"{req_id=}: sending timeout {timeout=}") - return self._get_result(req_id) - self.logger.debug(f"{req_id=}: sending complete") + try: + future = self.send_blob( + channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional + ) - # waiting for receiving first byte - self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}") - if not waiter.in_receiving.wait(timeout): - self.logger.info(f"{req_id=}: remote processing timeout {timeout=}") + self.logger.debug(f"{req_id=}: Waiting starts") + + # Three stages, sending, waiting for receiving first byte, receiving + # sending with progress timeout + self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") + sending_complete = self._future_wait(future, timeout) + if not sending_complete: + self.logger.debug(f"{req_id=}: sending timeout {timeout=}") + return self._get_result(req_id) + + self.logger.debug(f"{req_id=}: sending complete") + + # waiting for receiving first byte + self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}") + if not waiter.in_receiving.wait(timeout): + self.logger.debug(f"{req_id=}: remote processing timeout {timeout=}") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: in receiving") + + # receiving with progress timeout + r_future = waiter.receiving_future + self.logger.debug(f"{req_id=}: entering receiving wait 
{timeout=}") + receiving_complete = self._future_wait(r_future, timeout) + if not receiving_complete: + self.logger.info(f"{req_id=}: receiving timeout {timeout=}") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: receiving complete") + waiter.result = Message(r_future.headers, r_future.result()) + decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING) + self.logger.debug(f"{req_id=}: return result {waiter.result=}") return self._get_result(req_id) - self.logger.debug(f"{req_id=}: in receiving") - - # receiving with progress timeout - r_future = waiter.receiving_future - self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}") - receiving_complete = self._future_wait(r_future, timeout) - if not receiving_complete: - self.logger.info(f"{req_id=}: receiving timeout {timeout=}") + except Exception as ex: + self.logger.error(f"exception sending request: {secure_format_exception(ex)}") return self._get_result(req_id) - self.logger.debug(f"{req_id=}: receiving complete") - waiter.result = Message(r_future.headers, r_future.result()) - decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING) - self.logger.debug(f"{req_id=}: return result {waiter.result=}") - result = self._get_result(req_id) - return result def _process_reply(self, future: StreamFuture): headers = future.headers diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index ab95ec437e..56f6716454 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -15,6 +15,7 @@ import logging import queue import threading +import time from typing import Tuple, Union from nvflare.fuel.f3.cellnet.cell import Cell @@ -36,6 +37,8 @@ _HEADER_MSG_TYPE = _PREFIX + "msg_type" _HEADER_MSG_ID = _PREFIX + "msg_id" _HEADER_REQ_ID = _PREFIX + "req_id" +_HEADER_START_TIME = _PREFIX + "start" +_HEADER_HB_SEQ = _PREFIX + "hb_seq" def _cell_fqcn(mode, site_name, token): @@ -46,8 +49,10 @@ def _cell_fqcn(mode, site_name, token): return f"{site_name}_{token}_{mode}" -def _to_cell_message(msg: Message) -> CellMessage: - headers = {_HEADER_MSG_TYPE: msg.msg_type, _HEADER_MSG_ID: msg.msg_id} +def _to_cell_message(msg: Message, extra=None) -> CellMessage: + headers = {_HEADER_MSG_TYPE: msg.msg_type, _HEADER_MSG_ID: msg.msg_id, _HEADER_START_TIME: time.time()} + if extra: + headers.update(extra) if msg.req_id: headers[_HEADER_REQ_ID] = msg.req_id @@ -202,12 +207,29 @@ def __init__( self.channel = None # the cellnet message channel self.pipe_lock = threading.Lock() # used to ensure no msg to be sent after closed self.closed = False + self.last_peer_active_time = 0.0 + self.hb_seq = 1 + + def _update_peer_active_time(self, msg: CellMessage, ch_name: str, msg_type: str): + origin = msg.get_header(MessageHeaderKey.ORIGIN) + if origin == self.peer_fqcn: + self.logger.debug(f"{time.time()}: _update_peer_active_time: {ch_name=} {msg_type=} {msg.headers}") + self.last_peer_active_time = time.time() + + def get_last_peer_active_time(self): + return self.last_peer_active_time def set_cell_cb(self, channel_name: str): # This allows multiple pipes over the same cell (e.g. 
one channel for tasks, another for metrics), # as long as different pipes use different cell message channels self.channel = f"{_PREFIX}{channel_name}" self.cell.register_request_cb(channel=self.channel, topic="*", cb=self._receive_message) + self.cell.core_cell.add_incoming_request_filter( + channel="*", topic="*", cb=self._update_peer_active_time, ch_name=channel_name, msg_type="req" + ) + self.cell.core_cell.add_incoming_reply_filter( + channel="*", topic="*", cb=self._update_peer_active_time, ch_name=channel_name, msg_type="reply" + ) self.logger.info(f"registered CellPipe request CB for {self.channel}") def send(self, msg: Message, timeout=None) -> bool: @@ -225,31 +247,51 @@ def send(self, msg: Message, timeout=None) -> bool: if self.closed: raise BrokenPipeError("pipe closed") - optional = False - if msg.topic in [Topic.END, Topic.ABORT, Topic.HEARTBEAT]: - optional = True + # Note: the following code must not be within the lock scope + # Otherwise only one message can be sent at a time! + optional = False + if msg.topic in [Topic.END, Topic.ABORT, Topic.HEARTBEAT]: + optional = True + + if not timeout and msg.topic in [Topic.END, Topic.ABORT]: + timeout = 5.0 # need to keep the connection for some time; otherwise the msg may not go out + + if msg.topic == Topic.HEARTBEAT: + # for debugging purpose + extra_headers = {_HEADER_HB_SEQ: self.hb_seq} + self.hb_seq += 1 - reply = self.cell.send_request( + # don't need to wait for reply! + self.cell.fire_and_forget( channel=self.channel, topic=msg.topic, - target=self.peer_fqcn, - request=_to_cell_message(msg), - timeout=timeout, + targets=[self.peer_fqcn], + message=_to_cell_message(msg, extra_headers), optional=optional, ) - if reply: - rc = reply.get_header(MessageHeaderKey.RETURN_CODE) - if rc == ReturnCode.OK: - return True - else: - err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}" - if optional: - self.logger.debug(err) - else: - self.logger.error(err) - return False + return True + + reply = self.cell.send_request( + channel=self.channel, + topic=msg.topic, + target=self.peer_fqcn, + request=_to_cell_message(msg), + timeout=timeout, + optional=optional, + ) + if reply: + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == ReturnCode.OK: + return True else: + err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}" + if optional: + self.logger.debug(err) + else: + self.logger.error(err) return False + else: + return False def _receive_message(self, request: CellMessage) -> Union[None, CellMessage]: sender = request.get_header(MessageHeaderKey.ORIGIN) diff --git a/nvflare/fuel/utils/pipe/file_pipe.py b/nvflare/fuel/utils/pipe/file_pipe.py index 4b5ceb19bd..da211aa442 100644 --- a/nvflare/fuel/utils/pipe/file_pipe.py +++ b/nvflare/fuel/utils/pipe/file_pipe.py @@ -22,7 +22,7 @@ from nvflare.fuel.utils.pipe.file_accessor import FileAccessor from nvflare.fuel.utils.pipe.file_name_utils import file_name_to_message, message_to_file_name from nvflare.fuel.utils.pipe.fobs_file_accessor import FobsFileAccessor -from nvflare.fuel.utils.pipe.pipe import Message, Pipe +from nvflare.fuel.utils.pipe.pipe import Message, Pipe, Topic from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_number, check_str @@ -260,6 +260,10 @@ def send(self, msg: Message, timeout=None) -> bool: """ if not self.pipe_path: raise BrokenPipeError("pipe is not open") + + if not timeout and msg.topic in [Topic.END, Topic.ABORT, Topic.HEARTBEAT]: + timeout 
= 5.0 + return self.put_f(msg, timeout) def receive(self, timeout=None): diff --git a/nvflare/fuel/utils/pipe/pipe.py b/nvflare/fuel/utils/pipe/pipe.py index c4aeb81b3e..b928b01b5b 100644 --- a/nvflare/fuel/utils/pipe/pipe.py +++ b/nvflare/fuel/utils/pipe/pipe.py @@ -140,6 +140,14 @@ def can_resend(self) -> bool: """Whether the pipe is able to resend a message.""" pass + def get_last_peer_active_time(self): + """Get the last time that the peer is known to be active + + Returns: the last time that the peer is known to be active; or 0 if this info is not available + + """ + return 0 + def export(self, export_mode: str) -> Tuple[str, dict]: if export_mode == ExportMode.SELF: mode = self.mode diff --git a/nvflare/fuel/utils/pipe/pipe_handler.py b/nvflare/fuel/utils/pipe/pipe_handler.py index 4826c0bfa4..efbbcee134 100644 --- a/nvflare/fuel/utils/pipe/pipe_handler.py +++ b/nvflare/fuel/utils/pipe/pipe_handler.py @@ -61,7 +61,7 @@ def __init__( heartbeat_interval=5.0, heartbeat_timeout=30.0, resend_interval=2.0, - max_resends=None, + max_resends=5, default_request_timeout=5.0, ): """Constructor of the PipeHandler. @@ -166,6 +166,7 @@ def set_message_cb(self, cb, *args, **kwargs): def _send_to_pipe(self, msg: Message, timeout=None, abort_signal: Signal = None): pipe = self.pipe if not pipe: + self.logger.error("cannot send message to pipe since it's already closed") return False if not timeout or not pipe.can_resend() or not self.resend_interval: @@ -181,6 +182,7 @@ def _send_to_pipe(self, msg: Message, timeout=None, abort_signal: Signal = None) return sent if self.max_resends is not None and num_sends > self.max_resends: + self.logger.error(f"abort sending after {num_sends} tries") return False if self.asked_to_stop: @@ -208,10 +210,10 @@ def start(self): """Starts the PipeHandler. Note: before calling this method, the pipe managed by this PipeHandler must have been opened. """ - if not self.reader.is_alive(): + if self.reader and not self.reader.is_alive(): self.reader.start() - if not self.heartbeat_sender.is_alive(): + if self.heartbeat_sender and not self.heartbeat_sender.is_alive(): self.heartbeat_sender.start() def stop(self, close_pipe=True): @@ -256,7 +258,8 @@ def notify_end(self, data): p = self.pipe if p: try: - p.send(self._make_event_message(Topic.END, data)) + # fire and forget + p.send(self._make_event_message(Topic.END, data), 0.1) except Exception as ex: self.logger.debug(f"exception notify_end: {secure_format_exception(ex)}") @@ -265,7 +268,8 @@ def notify_abort(self, data): p = self.pipe if p: try: - p.send(self._make_event_message(Topic.ABORT, data)) + # fire and forget + p.send(self._make_event_message(Topic.ABORT, data), 0.1) except Exception as ex: self.logger.debug(f"exception notify_abort: {secure_format_exception(ex)}") @@ -310,6 +314,11 @@ def _try_read(self): break else: # is peer gone? + # ask the pipe for the last known active time of the peer + last_peer_active_time = self.pipe.get_last_peer_active_time() + if last_peer_active_time > self._last_heartbeat_received_time: + self._last_heartbeat_received_time = last_peer_active_time + if ( self.heartbeat_timeout and now - self._last_heartbeat_received_time > self.heartbeat_timeout @@ -339,6 +348,7 @@ def _heartbeat(self): last_heartbeat_sent_time = now time.sleep(self._check_interval) + self.heartbeat_sender = None def get_next(self) -> Optional[Message]: """Gets the next message from the message queue. 
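The get_last_peer_active_time() plumbing added above (the Pipe default, the CellPipe incoming-message filters, and the check in PipeHandler._try_read) effectively treats any message received from the peer as an implicit heartbeat. A stripped-down sketch of the resulting liveness check, with names simplified from the actual PipeHandler code:

    import time


    def peer_is_alive(pipe, last_heartbeat_received_time: float, heartbeat_timeout: float) -> bool:
        """Decide whether the peer should still be considered alive.

        Any traffic recently seen from the peer (reported by the pipe via
        get_last_peer_active_time) refreshes the liveness timestamp, so a
        missed explicit heartbeat alone no longer marks the peer as gone.
        """
        if not heartbeat_timeout:
            return True  # heartbeat checking disabled
        last_seen = max(last_heartbeat_received_time, pipe.get_last_peer_active_time())
        return (time.time() - last_seen) <= heartbeat_timeout

In PipeHandler._try_read the same comparison decides whether the peer is treated as gone; the CellPipe filter changes above are what make get_last_peer_active_time return something more recent than the default of 0.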
From 2eed4071a50a0ec2d6124802bca509986e4c85cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Thu, 18 Apr 2024 17:22:32 -0700 Subject: [PATCH 12/14] Update ci cd from 2.4 (#2520) * Update github actions (#2450) * Fix premerge (#2467) * Fix issues on hello-world TF2 notebook * Fix tf integration test (#2504) * Add client api integration tests --------- Co-authored-by: Isaac Yang Co-authored-by: Sean Yang --- .github/workflows/blossom-ci.yml | 2 +- .github/workflows/codeql.yml | 2 +- .github/workflows/markdown-links-check.yml | 2 +- .github/workflows/premerge.yml | 18 +- .../jobs/hello-tf2/app/custom/trainer.py | 4 + examples/hello-world/hello_world.ipynb | 59 +----- .../app_client/config/config_fed_client.conf | 43 ++++ .../np_loop/app_client/custom/train_diff.py | 60 ++++++ .../np_loop/app_client/custom/train_full.py | 59 ++++++ .../np_loop/app_client/custom/train_loop.py | 68 +++++++ .../app_client/custom/train_metrics.py | 86 ++++++++ .../app_server/config/config_fed_server.conf | 47 +++++ .../data/jobs/np_loop/meta.conf | 15 ++ .../app/config/config_fed_client.conf | 47 +++++ .../app/config/config_fed_server.conf | 47 +++++ .../app/custom/train_diff.py | 60 ++++++ .../app/custom/train_full.py | 59 ++++++ .../app/custom/train_loop.py | 68 +++++++ .../app/custom/train_metrics.py | 86 ++++++++ .../data/jobs/np_loop_cell_pipe/meta.conf | 11 + .../app/config/config_fed_client.conf | 77 +++++++ .../app/config/config_fed_server.conf | 56 ++++++ .../jobs/np_metrics/app/custom/train_diff.py | 60 ++++++ .../jobs/np_metrics/app/custom/train_full.py | 59 ++++++ .../jobs/np_metrics/app/custom/train_loop.py | 68 +++++++ .../np_metrics/app/custom/train_metrics.py | 86 ++++++++ .../data/jobs/np_metrics/meta.conf | 11 + .../app/config/config_fed_client.conf | 77 +++++++ .../app/config/config_fed_server.conf | 62 ++++++ .../pt_client_api/app/custom/cifar10_fl.py | 137 +++++++++++++ .../data/jobs/pt_client_api/app/custom/net.py | 37 ++++ .../data/jobs/pt_client_api/meta.conf | 11 + .../app/config/config_fed_client.conf | 43 ++++ .../app/config/config_fed_server.conf | 38 ++++ .../app/custom/cifar10_fl.py | 132 ++++++++++++ .../pt_client_api_cyclic/app/custom/net.py | 37 ++++ .../data/jobs/pt_client_api_cyclic/meta.conf | 11 + .../app/config/config_fed_client.conf | 77 +++++++ .../app/config/config_fed_server.conf | 62 ++++++ .../app/custom/cifar10_fl.py | 137 +++++++++++++ .../app/custom/net.py | 37 ++++ .../jobs/pt_client_api_launch_once/meta.conf | 11 + .../app/config/config_fed_client.conf | 120 +++++++++++ .../app/config/config_fed_server.conf | 106 ++++++++++ .../jobs/qa_job_4558419/app/custom/net.py | 25 +++ .../jobs/qa_job_4558419/app/custom/train.py | 68 +++++++ .../data/jobs/qa_job_4558419/meta.conf | 11 + .../app/config/config_fed_client.conf | 77 +++++++ .../app/config/config_fed_server.conf | 62 ++++++ .../jobs/qa_job_4561583/app/custom/net.py | 25 +++ .../jobs/qa_job_4561583/app/custom/pl_net.py | 66 ++++++ .../qa_job_4561583/app/custom/poc_executor.py | 41 ++++ .../data/jobs/qa_job_4561583/meta.conf | 11 + .../app/config/config_fed_client.conf | 144 ++++++++++++++ .../app/config/config_fed_server.conf | 33 +++ .../jobs/qa_job_4561872/app/custom/net.py | 37 ++++ .../jobs/qa_job_4561872/app/custom/train.py | 188 ++++++++++++++++++ .../data/jobs/qa_job_4561872/meta.conf | 11 + .../app/config/config_fed_client.conf | 77 +++++++ .../app/config/config_fed_server.conf | 62 ++++++ .../jobs/qa_job_4592780/app/custom/net.py | 25 +++ 
.../qa_job_4592780/app/custom/poc_executor.py | 78 ++++++++ .../data/jobs/qa_job_4592780/meta.conf | 11 + .../standalone_job/client_api.yml | 148 ++++++++++++++ .../standalone_job/client_api_qa.yml | 102 ++++++++++ .../integration_test/run_integration_tests.sh | 2 +- tests/integration_test/src/utils.py | 4 + tests/integration_test/test_configs.yml | 4 + 68 files changed, 3645 insertions(+), 62 deletions(-) create mode 100644 tests/integration_test/data/jobs/np_loop/app_client/config/config_fed_client.conf create mode 100755 tests/integration_test/data/jobs/np_loop/app_client/custom/train_diff.py create mode 100755 tests/integration_test/data/jobs/np_loop/app_client/custom/train_full.py create mode 100755 tests/integration_test/data/jobs/np_loop/app_client/custom/train_loop.py create mode 100755 tests/integration_test/data/jobs/np_loop/app_client/custom/train_metrics.py create mode 100644 tests/integration_test/data/jobs/np_loop/app_server/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/np_loop/meta.conf create mode 100644 tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_server.conf create mode 100755 tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_diff.py create mode 100755 tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_full.py create mode 100755 tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_loop.py create mode 100755 tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_metrics.py create mode 100644 tests/integration_test/data/jobs/np_loop_cell_pipe/meta.conf create mode 100644 tests/integration_test/data/jobs/np_metrics/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/np_metrics/app/config/config_fed_server.conf create mode 100755 tests/integration_test/data/jobs/np_metrics/app/custom/train_diff.py create mode 100755 tests/integration_test/data/jobs/np_metrics/app/custom/train_full.py create mode 100755 tests/integration_test/data/jobs/np_metrics/app/custom/train_loop.py create mode 100755 tests/integration_test/data/jobs/np_metrics/app/custom/train_metrics.py create mode 100644 tests/integration_test/data/jobs/np_metrics/meta.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api/app/custom/cifar10_fl.py create mode 100644 tests/integration_test/data/jobs/pt_client_api/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/pt_client_api/meta.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/cifar10_fl.py create mode 100644 tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/pt_client_api_cyclic/meta.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_server.conf create mode 100644 
tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/cifar10_fl.py create mode 100644 tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/pt_client_api_launch_once/meta.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4558419/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/qa_job_4558419/app/custom/train.py create mode 100644 tests/integration_test/data/jobs/qa_job_4558419/meta.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4561583/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/qa_job_4561583/app/custom/pl_net.py create mode 100644 tests/integration_test/data/jobs/qa_job_4561583/app/custom/poc_executor.py create mode 100644 tests/integration_test/data/jobs/qa_job_4561583/meta.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4561872/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/qa_job_4561872/app/custom/train.py create mode 100644 tests/integration_test/data/jobs/qa_job_4561872/meta.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_client.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_server.conf create mode 100644 tests/integration_test/data/jobs/qa_job_4592780/app/custom/net.py create mode 100644 tests/integration_test/data/jobs/qa_job_4592780/app/custom/poc_executor.py create mode 100644 tests/integration_test/data/jobs/qa_job_4592780/meta.conf create mode 100644 tests/integration_test/data/test_configs/standalone_job/client_api.yml create mode 100644 tests/integration_test/data/test_configs/standalone_job/client_api_qa.yml diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 844cdf1c93..ce13d01fbb 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -74,7 +74,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: repository: ${{ fromJson(needs.Authorization.outputs.args).repo }} ref: ${{ fromJson(needs.Authorization.outputs.args).ref }} diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 193f7b48e5..0425542192 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -36,7 +36,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL diff --git a/.github/workflows/markdown-links-check.yml b/.github/workflows/markdown-links-check.yml index 56fe4e7982..1a8686ea30 100644 --- a/.github/workflows/markdown-links-check.yml +++ b/.github/workflows/markdown-links-check.yml @@ -23,7 +23,7 @@ jobs: markdown-link-check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@master + - uses: actions/checkout@v4 - uses: gaurav-nelson/github-action-markdown-link-check@1.0.15 with: max-depth: -1 diff --git a/.github/workflows/premerge.yml b/.github/workflows/premerge.yml index 932275df7e..d3de85e167 100644 --- a/.github/workflows/premerge.yml +++ b/.github/workflows/premerge.yml @@ -29,15 +29,15 @@ jobs: os: [ ubuntu-22.04, ubuntu-20.04 ] python-version: [ "3.8", "3.9", "3.10" ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e .[dev] + python3 -m pip install --upgrade pip + python3 -m pip install --no-cache-dir -e .[dev] - name: Run unit test run: ./runtest.sh @@ -49,15 +49,15 @@ jobs: os: [ ubuntu-22.04, ubuntu-20.04 ] python-version: [ "3.8", "3.9", "3.10" ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e .[dev] - pip install build twine torch torchvision + python3 -m pip install --upgrade pip + python3 -m pip install --no-cache-dir -e .[dev] + python3 -m pip install --no-cache-dir build twine torch torchvision - name: Run wheel build run: python3 -m build --wheel diff --git a/examples/hello-world/hello-tf2/jobs/hello-tf2/app/custom/trainer.py b/examples/hello-world/hello-tf2/jobs/hello-tf2/app/custom/trainer.py index 720f795198..25725687d4 100644 --- a/examples/hello-world/hello-tf2/jobs/hello-tf2/app/custom/trainer.py +++ b/examples/hello-world/hello-tf2/jobs/hello-tf2/app/custom/trainer.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + +os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" + import numpy as np import tensorflow as tf from tf2_net import Net diff --git a/examples/hello-world/hello_world.ipynb b/examples/hello-world/hello_world.ipynb index a1853e1440..cd8e9b8f23 100644 --- a/examples/hello-world/hello_world.ipynb +++ b/examples/hello-world/hello_world.ipynb @@ -738,8 +738,9 @@ ] }, { + "attachments": {}, "cell_type": "markdown", - "id": "425c7f1d-7cb6-4602-bb88-4b87ff517529", + "id": "696a384e-f6da-4044-a7b9-db4c464f52ac", "metadata": {}, "source": [ "#### Running Tensorflow on local host with GPU \n", @@ -748,53 +749,13 @@ "We are running with 1 server, 2 sites in a local machine, which means three process involved for this federated training. \n", "If the local host has GPU, you might enter OOM error, due to the way Tensorflow consumes GPU memory. By default, TensorFlow maps nearly all of the GPU memory of all GPUs (subject to CUDA_VISIBLE_DEVICES) visible to the process. If one has multiple process, some of the process will be OOM. 
To avoid multiple processes grabbing all GPU memory in TF, use the options described in [Limiting GPU memory growth]( https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth). \n", "\n", - "In our cases, we prefer that the process only allocates a subset of the available memory, or to only grow the memory usage as is needed by the process. TensorFlow provides two methods to control this. \n", + "In our case, we prefer that the process only allocates a subset of the available memory, or only grows the memory usage as the process needs it. TensorFlow provides two methods to control this, as described in the link above.\n", + "\n", + "In this example, we explicitly set the environment variable `TF_FORCE_GPU_ALLOW_GROWTH` to `true` at the very beginning of the trainer.py file, which runs on the clients and allocates GPU memory for training. With this environment variable set, TF will not grab the entire GPU memory and will not cause a GPU OOM error when running the POC on a local host.\n", + "\n", + "Note that setting the environment variable `TF_FORCE_GPU_ALLOW_GROWTH` inside this notebook has no effect because the POC clients have already started and their environment variables were set at startup time.\n", + "\n", "\n", - "The First method is set the environmental variable TF_FORCE_GPU_ALLOW_GROWTH to true. This configuration is platform specific. \n", - "The 2nd method is using the piece of code below" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "026ce33a-90b1-4ef7-8c7c-f25722f2d2ae", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "%env TF_FORCE_GPU_ALLOW_GROWTH=true" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91cf84ec-57f1-439f-b72b-b33fea1a7f7c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "gpus = tf.config.list_physical_devices('GPU')\n", - "if gpus:\n", - " # Restrict TensorFlow to only allocate 1GB of memory on the first GPU\n", - " try:\n", - " tf.config.set_logical_device_configuration(\n", - " gpus[0],\n", - " [tf.config.LogicalDeviceConfiguration(memory_limit=1024)])\n", - " logical_gpus = tf.config.list_logical_devices('GPU')\n", - " print(len(gpus), \"Physical GPUs,\", len(logical_gpus), \"Logical GPUs\")\n", - " except RuntimeError as e:\n", - " # Virtual devices must be set before GPUs have been initialized\n", - " print(e)" - ] - }, - { - "cell_type": "markdown", - "id": "0c3f1149-d54d-4f3b-b497-c06deba22fef", - "metadata": {}, - "source": [ "### 1. Submit job using FLARE API\n", "\n", "Starting a FLARE API session and submit the hello-tf2 job\n", @@ -833,7 +794,7 @@ }, "outputs": [], "source": [ - "! tail -100 /tmp/nvflare/poc/server/log.txt" + "! 
tail -100 /tmp/nvflare/poc/example_project/prod_00/server/log.txt" ] }, { @@ -977,7 +938,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.8.15" }, "vscode": { "interpreter": { diff --git a/tests/integration_test/data/jobs/np_loop/app_client/config/config_fed_client.conf b/tests/integration_test/data/jobs/np_loop/app_client/config/config_fed_client.conf new file mode 100644 index 0000000000..5248dae97a --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/app_client/config/config_fed_client.conf @@ -0,0 +1,43 @@ +{ + format_version = 2 + app_script = "train_loop.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_common.executors.client_api_launcher_executor.ClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "numpy" + params_transfer_type = "FULL" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 custom/{app_script} {app_config} " + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe" + args { + mode = "PASSIVE" + root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}" + } + } + ] +} diff --git a/tests/integration_test/data/jobs/np_loop/app_client/custom/train_diff.py b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_diff.py new file mode 100755 index 0000000000..08f3c8e7c0 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_diff.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get model from NVFlare + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + # calculate difference here + diff = output_numpy_array - input_numpy_array + + # send back the model difference + print(f"send back: {diff}") + flare.send(flare.FLModel(params={"numpy_key": diff}, params_type="DIFF", metrics={"accuracy": metrics})) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop/app_client/custom/train_full.py b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_full.py new file mode 100755 index 0000000000..9c05536b85 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_full.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get model from NVFlare + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel(params={"numpy_key": output_numpy_array}, params_type="FULL", metrics={"accuracy": metrics}) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop/app_client/custom/train_loop.py b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_loop.py new file mode 100755 index 0000000000..ea9ec25149 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_loop.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}", flush=True) + + while flare.is_running(): + input_model = flare.receive() + print(f"received weights is: {input_model.params}", flush=True) + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + sys_info = flare.system_info() + print(f"system info is: {sys_info}", flush=True) + print(f"finish round: {input_model.current_round}", flush=True) + + # send back the model + print(f"send back: {output_numpy_array}", flush=True) + flare.send( + flare.FLModel( + params={"numpy_key": output_numpy_array}, + params_type="FULL", + metrics={"accuracy": metrics}, + current_round=input_model.current_round, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop/app_client/custom/train_metrics.py b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_metrics.py new file mode 100755 index 0000000000..e508b74f30 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/app_client/custom/train_metrics.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import time + +import nvflare.client as flare +from nvflare.client.tracking import MLflowWriter + + +def train(input_arr, current_round, epochs=3): + writer = MLflowWriter() + output_arr = copy.deepcopy(input_arr) + num_of_data = 2000 + batch_size = 16 + num_of_batches = num_of_data // batch_size + for i in range(epochs): + for j in range(num_of_batches): + global_step = current_round * num_of_batches * epochs + i * num_of_batches + j + print(f"logging record: {global_step}") + writer.log_metric( + key="global_step", + value=global_step, + step=global_step, + ) + # mock training with plus 1 + output_arr += 1 + # assume each epoch takes 1 seconds + time.sleep(1.0) + return output_arr + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + while flare.is_running(): + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array, current_round=input_model.current_round, epochs=3) + + # evaluation + metrics = evaluate(input_numpy_array) + + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + print(f"finish round: {input_model.current_round}") + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel( + params={"numpy_key": output_numpy_array}, + params_type="FULL", + metrics={"accuracy": metrics}, + current_round=input_model.current_round, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop/app_server/config/config_fed_server.conf b/tests/integration_test/data/jobs/np_loop/app_server/config/config_fed_server.conf new file mode 100644 index 0000000000..36dcd2c344 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/app_server/config/config_fed_server.conf @@ -0,0 +1,47 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 5 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_common.np.np_model_persistor.NPModelPersistor" + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHTS" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + ] +} diff --git a/tests/integration_test/data/jobs/np_loop/meta.conf b/tests/integration_test/data/jobs/np_loop/meta.conf new file mode 100644 index 0000000000..5e421f1319 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop/meta.conf @@ -0,0 +1,15 @@ +{ + name = "np_loop" + resource_spec {} + deploy_map { + app_server = [ + "server" + ], + app_client = [ + "site-1", + "site-2" + ] + } + min_clients = 2 + 
mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_client.conf new file mode 100644 index 0000000000..458f50f568 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_client.conf @@ -0,0 +1,47 @@ +{ + format_version = 2 + app_script = "train_loop.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_common.executors.client_api_launcher_executor.ClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "numpy" + params_transfer_type = "FULL" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 custom/{app_script} {app_config} " + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + ] +} diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_server.conf new file mode 100644 index 0000000000..36dcd2c344 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/config/config_fed_server.conf @@ -0,0 +1,47 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 5 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_common.np.np_model_persistor.NPModelPersistor" + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHTS" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + ] +} diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_diff.py b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_diff.py new file mode 100755 index 0000000000..08f3c8e7c0 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_diff.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get model from NVFlare + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + # calculate difference here + diff = output_numpy_array - input_numpy_array + + # send back the model difference + print(f"send back: {diff}") + flare.send(flare.FLModel(params={"numpy_key": diff}, params_type="DIFF", metrics={"accuracy": metrics})) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_full.py b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_full.py new file mode 100755 index 0000000000..9c05536b85 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_full.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get model from NVFlare + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel(params={"numpy_key": output_numpy_array}, params_type="FULL", metrics={"accuracy": metrics}) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_loop.py b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_loop.py new file mode 100755 index 0000000000..ea9ec25149 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_loop.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}", flush=True) + + while flare.is_running(): + input_model = flare.receive() + print(f"received weights is: {input_model.params}", flush=True) + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + sys_info = flare.system_info() + print(f"system info is: {sys_info}", flush=True) + print(f"finish round: {input_model.current_round}", flush=True) + + # send back the model + print(f"send back: {output_numpy_array}", flush=True) + flare.send( + flare.FLModel( + params={"numpy_key": output_numpy_array}, + params_type="FULL", + metrics={"accuracy": metrics}, + current_round=input_model.current_round, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_metrics.py b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_metrics.py new file mode 100755 index 0000000000..e508b74f30 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/app/custom/train_metrics.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import time + +import nvflare.client as flare +from nvflare.client.tracking import MLflowWriter + + +def train(input_arr, current_round, epochs=3): + writer = MLflowWriter() + output_arr = copy.deepcopy(input_arr) + num_of_data = 2000 + batch_size = 16 + num_of_batches = num_of_data // batch_size + for i in range(epochs): + for j in range(num_of_batches): + global_step = current_round * num_of_batches * epochs + i * num_of_batches + j + print(f"logging record: {global_step}") + writer.log_metric( + key="global_step", + value=global_step, + step=global_step, + ) + # mock training with plus 1 + output_arr += 1 + # assume each epoch takes 1 seconds + time.sleep(1.0) + return output_arr + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + while flare.is_running(): + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array, current_round=input_model.current_round, epochs=3) + + # evaluation + metrics = evaluate(input_numpy_array) + + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + print(f"finish round: {input_model.current_round}") + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel( + params={"numpy_key": output_numpy_array}, + params_type="FULL", + metrics={"accuracy": metrics}, + current_round=input_model.current_round, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_loop_cell_pipe/meta.conf b/tests/integration_test/data/jobs/np_loop_cell_pipe/meta.conf new file mode 100644 index 0000000000..1b27fd54a7 --- /dev/null +++ b/tests/integration_test/data/jobs/np_loop_cell_pipe/meta.conf @@ -0,0 +1,11 @@ +{ + name = "np_loop_cell_pipe" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/np_metrics/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/np_metrics/app/config/config_fed_client.conf new file mode 100644 index 0000000000..455684cf80 --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/app/config/config_fed_client.conf @@ -0,0 +1,77 @@ +{ + format_version = 2 + app_script = "train_metrics.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_common.executors.client_api_launcher_executor.ClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "numpy" + params_transfer_type = "DIFF" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = 
"nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 custom/{app_script} {app_config} " + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + read_interval = 0.1 + } + } + { + id = "client_api_config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = [ + "metric_relay" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/np_metrics/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/np_metrics/app/config/config_fed_server.conf new file mode 100644 index 0000000000..a5f1f0856c --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/app/config/config_fed_server.conf @@ -0,0 +1,56 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 5 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_common.np.np_model_persistor.NPModelPersistor" + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHT_DIFF" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + { + id = "tb_analytics_receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args { + events = [ + "fed.analytix_log_stats" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/np_metrics/app/custom/train_diff.py b/tests/integration_test/data/jobs/np_metrics/app/custom/train_diff.py new file mode 100755 index 0000000000..08f3c8e7c0 --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/app/custom/train_diff.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get model from NVFlare + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + # calculate difference here + diff = output_numpy_array - input_numpy_array + + # send back the model difference + print(f"send back: {diff}") + flare.send(flare.FLModel(params={"numpy_key": diff}, params_type="DIFF", metrics={"accuracy": metrics})) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_metrics/app/custom/train_full.py b/tests/integration_test/data/jobs/np_metrics/app/custom/train_full.py new file mode 100755 index 0000000000..9c05536b85 --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/app/custom/train_full.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get model from NVFlare + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel(params={"numpy_key": output_numpy_array}, params_type="FULL", metrics={"accuracy": metrics}) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_metrics/app/custom/train_loop.py b/tests/integration_test/data/jobs/np_metrics/app/custom/train_loop.py new file mode 100755 index 0000000000..c7916efeec --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/app/custom/train_loop.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import nvflare.client as flare + + +def train(input_arr): + output_arr = copy.deepcopy(input_arr) + # mock training with plus 1 + return output_arr + 1 + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + while flare.is_running(): + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array) + + # evaluation + metrics = evaluate(input_numpy_array) + + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + print(f"finish round: {input_model.current_round}") + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel( + params={"numpy_key": output_numpy_array}, + params_type="FULL", + metrics={"accuracy": metrics}, + current_round=input_model.current_round, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_metrics/app/custom/train_metrics.py b/tests/integration_test/data/jobs/np_metrics/app/custom/train_metrics.py new file mode 100755 index 0000000000..e5e7cd4257 --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/app/custom/train_metrics.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import time + +import nvflare.client as flare +from nvflare.client.tracking import MLflowWriter + + +def train(input_arr, current_round, epochs=3): + writer = MLflowWriter() + output_arr = copy.deepcopy(input_arr) + num_of_data = 2000 + batch_size = 16 + num_of_batches = num_of_data // batch_size + for i in range(epochs): + for j in range(num_of_batches): + global_step = current_round * num_of_batches * epochs + i * num_of_batches + j + print(f"logging record: {global_step}", flush=True) + writer.log_metric( + key="global_step", + value=global_step, + step=global_step, + ) + # mock training with plus 1 + output_arr += 1 + # assume each epoch takes 1 seconds + time.sleep(1.0) + return output_arr + + +def evaluate(input_arr): + # mock evaluation metrics + return 100 + + +def main(): + # initializes NVFlare interface + flare.init() + + # get system information + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + + while flare.is_running(): + input_model = flare.receive() + print(f"received weights is: {input_model.params}") + + input_numpy_array = input_model.params["numpy_key"] + + # training + output_numpy_array = train(input_numpy_array, current_round=input_model.current_round, epochs=3) + + # evaluation + metrics = evaluate(input_numpy_array) + + sys_info = flare.system_info() + print(f"system info is: {sys_info}") + print(f"finish round: {input_model.current_round}") + + # send back the model + print(f"send back: {output_numpy_array}") + flare.send( + flare.FLModel( + params={"numpy_key": output_numpy_array}, + params_type="FULL", + metrics={"accuracy": metrics}, + current_round=input_model.current_round, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/np_metrics/meta.conf b/tests/integration_test/data/jobs/np_metrics/meta.conf new file mode 100644 index 0000000000..13374f8a0e --- /dev/null +++ b/tests/integration_test/data/jobs/np_metrics/meta.conf @@ -0,0 +1,11 @@ +{ + name = "np_metrics" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_client.conf new file mode 100644 index 0000000000..289f0e81a0 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_client.conf @@ -0,0 +1,77 @@ +{ + format_version = 2 + app_script = "cifar10_fl.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "pytorch" + params_transfer_type = "DIFF" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 -u custom/{app_script} {app_config} " + launch_once = false + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + 
root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + read_interval = 0.1 + } + } + { + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = [ + "metric_relay" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_server.conf new file mode 100644 index 0000000000..8245a2d527 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api/app/config/config_fed_server.conf @@ -0,0 +1,62 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + model_class_path = "net.Net" + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 2 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + args { + model { + path = "{model_class_path}" + } + } + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHT_DIFF" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args { + events = [ + "fed.analytix_log_stats" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/pt_client_api/app/custom/cifar10_fl.py b/tests/integration_test/data/jobs/pt_client_api/app/custom/cifar10_fl.py new file mode 100644 index 0000000000..25fc291e98 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api/app/custom/cifar10_fl.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from net import Net + +# (1) import nvflare client API +import nvflare.client as flare + +# (optional) metrics +from nvflare.client.tracking import SummaryWriter + +# (optional) set a fix place so we don't need to download everytime +DATASET_PATH = "/tmp/nvflare/data" +# (optional) We change to use GPU to speed things up. 
+# if you want to use CPU, change DEVICE="cpu" +DEVICE = "cuda:0" + + +def main(): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + batch_size = 4 + epochs = 2 + + trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + + testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) + + net = Net() + + # (2) initializes NVFlare client API + flare.init() + + summary_writer = SummaryWriter() + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (4) loads model from NVFlare + net.load_state_dict(input_model.params) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # (optional) use GPU to speed things up + net.to(DEVICE) + # (optional) calculate total steps + steps = epochs * len(trainloader) + for epoch in range(epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + global_step = input_model.current_round * steps + epoch * len(trainloader) + i + + summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(net.state_dict(), PATH) + + # (5) wraps evaluation logic into a method to re-use for + # evaluation on both trained and received model + def evaluate(input_weights): + net = Net() + net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + net.to(DEVICE) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(DEVICE), data[1].to(DEVICE) + # calculate outputs by running images through the network + outputs = net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") + return 100 * correct // total + + # (6) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + # (7) construct trained FL model + output_model = flare.FLModel( + params=net.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + # (8) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git 
a/tests/integration_test/data/jobs/pt_client_api/app/custom/net.py b/tests/integration_test/data/jobs/pt_client_api/app/custom/net.py new file mode 100644 index 0000000000..47ac7e9589 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api/app/custom/net.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/tests/integration_test/data/jobs/pt_client_api/meta.conf b/tests/integration_test/data/jobs/pt_client_api/meta.conf new file mode 100644 index 0000000000..479de325a2 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api/meta.conf @@ -0,0 +1,11 @@ +{ + name = "client_api" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_client.conf new file mode 100644 index 0000000000..1809d849c4 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_client.conf @@ -0,0 +1,43 @@ +{ + format_version = 2 + app_script = "cifar10_fl.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "pytorch" + params_transfer_type = "DIFF" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 -u custom/{app_script} {app_config} " + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe" + args { + mode = "PASSIVE" + root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}" + } + } + ] +} diff --git a/tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_server.conf new file mode 100644 index 0000000000..a194c1ed9e --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/config/config_fed_server.conf @@ -0,0 +1,38 @@ +{ + format_version = 2 + server { + heart_beat_timeout = 600 + } + 
task_data_filters = [] + task_result_filters = [] + model_class_path = "net.Net" + components = [ + { + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + args { + model { + path = "{model_class_path}" + } + } + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + ] + workflows = [ + { + id = "cyclic_ctl" + path = "nvflare.app_common.workflows.cyclic_ctl.CyclicController" + args { + num_rounds = 3 + task_assignment_timeout = 8 + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + task_name = "train" + } + } + ] +} diff --git a/tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/cifar10_fl.py b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/cifar10_fl.py new file mode 100644 index 0000000000..80df2f4086 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/cifar10_fl.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from net import Net + +# (1) import nvflare client API +import nvflare.client as flare + +# (optional) set a fix place so we don't need to download everytime +DATASET_PATH = "/tmp/nvflare/data" +# (optional) We change to use GPU to speed things up. 
+# if you want to use CPU, change DEVICE="cpu" +DEVICE = "cuda:0" + + +def main(): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + batch_size = 4 + epochs = 2 + + trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + + testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) + + net = Net() + + # (2) initializes NVFlare client API + flare.init() + + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (4) loads model from NVFlare + net.load_state_dict(input_model.params) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # (optional) use GPU to speed things up + net.to(DEVICE) + # (optional) calculate total steps + steps = epochs * len(trainloader) + for epoch in range(epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + global_step = input_model.current_round * steps + epoch * len(trainloader) + i + + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(net.state_dict(), PATH) + + # (5) wraps evaluation logic into a method to re-use for + # evaluation on both trained and received model + def evaluate(input_weights): + net = Net() + net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + net.to(DEVICE) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(DEVICE), data[1].to(DEVICE) + # calculate outputs by running images through the network + outputs = net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") + return 100 * correct // total + + # (6) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + # (7) construct trained FL model + output_model = flare.FLModel( + params=net.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + # (8) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/net.py b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/net.py new file mode 100644 index 
0000000000..47ac7e9589 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_cyclic/app/custom/net.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/tests/integration_test/data/jobs/pt_client_api_cyclic/meta.conf b/tests/integration_test/data/jobs/pt_client_api_cyclic/meta.conf new file mode 100644 index 0000000000..f6ce5c1810 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_cyclic/meta.conf @@ -0,0 +1,11 @@ +{ + name = "client_api" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 1 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_client.conf new file mode 100644 index 0000000000..2d3d180531 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_client.conf @@ -0,0 +1,77 @@ +{ + format_version = 2 + app_script = "cifar10_fl.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "pytorch" + params_transfer_type = "DIFF" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 -u custom/{app_script} {app_config} " + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + read_interval = 0.1 + } + } + { + id = 
"config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = [ + "metric_relay" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_server.conf new file mode 100644 index 0000000000..8245a2d527 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/config/config_fed_server.conf @@ -0,0 +1,62 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + model_class_path = "net.Net" + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 2 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + args { + model { + path = "{model_class_path}" + } + } + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHT_DIFF" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args { + events = [ + "fed.analytix_log_stats" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/cifar10_fl.py b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/cifar10_fl.py new file mode 100644 index 0000000000..25fc291e98 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/cifar10_fl.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from net import Net + +# (1) import nvflare client API +import nvflare.client as flare + +# (optional) metrics +from nvflare.client.tracking import SummaryWriter + +# (optional) set a fix place so we don't need to download everytime +DATASET_PATH = "/tmp/nvflare/data" +# (optional) We change to use GPU to speed things up. 
+# if you want to use CPU, change DEVICE="cpu" +DEVICE = "cuda:0" + + +def main(): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + batch_size = 4 + epochs = 2 + + trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + + testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) + + net = Net() + + # (2) initializes NVFlare client API + flare.init() + + summary_writer = SummaryWriter() + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (4) loads model from NVFlare + net.load_state_dict(input_model.params) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # (optional) use GPU to speed things up + net.to(DEVICE) + # (optional) calculate total steps + steps = epochs * len(trainloader) + for epoch in range(epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + global_step = input_model.current_round * steps + epoch * len(trainloader) + i + + summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(net.state_dict(), PATH) + + # (5) wraps evaluation logic into a method to re-use for + # evaluation on both trained and received model + def evaluate(input_weights): + net = Net() + net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + net.to(DEVICE) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(DEVICE), data[1].to(DEVICE) + # calculate outputs by running images through the network + outputs = net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") + return 100 * correct // total + + # (6) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + # (7) construct trained FL model + output_model = flare.FLModel( + params=net.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + # (8) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git 
a/tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/net.py b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/net.py new file mode 100644 index 0000000000..47ac7e9589 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_launch_once/app/custom/net.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/tests/integration_test/data/jobs/pt_client_api_launch_once/meta.conf b/tests/integration_test/data/jobs/pt_client_api_launch_once/meta.conf new file mode 100644 index 0000000000..479de325a2 --- /dev/null +++ b/tests/integration_test/data/jobs/pt_client_api_launch_once/meta.conf @@ -0,0 +1,11 @@ +{ + name = "client_api" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_client.conf new file mode 100644 index 0000000000..5b41e42a04 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_client.conf @@ -0,0 +1,120 @@ +{ + # version of the configuration + format_version = 2 + + # This is the application script which will be invoked. Client can replace this script with user's own training script. + app_script = "train.py" + + # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. + app_config = "" + + # Client Computing Executors. + executors = [ + { + # tasks the executors are defined to handle + tasks = ["train"] + + # This particular executor + executor { + + # This is an executor for Client API. The underline data exchange is using Pipe. + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + + args { + # launcher_id is used to locate the Launcher object in "components" + launcher_id = "launcher" + + # pipe_id is used to locate the Pipe object in "components" + pipe_id = "pipe" + + # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. 
+ # Please refer to the class docstring for all available arguments + heartbeat_timeout = 10 + + peer_read_timeout = 120 + + external_pre_init_timeout = 120 + + # format of the exchange parameters + params_exchange_format = "pytorch" + + # if the transfer_type is FULL, then it will be sent directly + # if the transfer_type is DIFF, then we will calculate the + # difference VS received parameters and send the difference + params_transfer_type = "DIFF" + + # if train_with_evaluation is true, the executor will expect + # the custom code need to send back both the trained parameters and the evaluation metric + # otherwise only trained parameters are expected + train_with_evaluation = true + } + } + } + ], + + # this defined an array of task data filters. If provided, it will control the data from server controller to client executor + task_data_filters = [] + + # this defined an array of task result filters. If provided, it will control the result from client executor to server controller + task_result_filters = [] + + components = [ + { + # component id is "launcher" + id = "launcher" + + # the class path of this component + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + + args { + # the launcher will invoke the script + script = "python3 -u custom/{app_script} {app_config} " + # if launch_once is true, the SubprocessLauncher will launch once for the whole job + # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server + launch_once = false + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + }, + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + # how fast should it read from the peer + read_interval = 0.1 + } + }, + { + # we use this component so the client api `flare.init()` can get required information + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = ["metric_relay"] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_server.conf new file mode 100644 index 0000000000..da5b075533 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4558419/app/config/config_fed_server.conf @@ -0,0 +1,106 @@ +{ + # version of the configuration + format_version = 2 + + # task data filter: if filters are provided, the filter will filter the data flow out of server to client. + task_data_filters =[] + + # task result filter: if filters are provided, the filter will filter the result flow out of client to server. + task_result_filters = [] + + # This assumes that there will be a "net.py" file with class name "Net". + # If your model code is not in "net.py" and class name is not "Net", please modify here + model_class_path = "net.Net" + + # workflows: Array of workflows the control the Federated Learning workflow lifecycle. 
+ # One can specify multiple workflows. The NVFLARE will run them in the order specified. + workflows = [ + { + # 1st workflow" + id = "scatter_and_gather" + + # name = ScatterAndGather, path is the class path of the ScatterAndGather controller. + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + # argument of the ScatterAndGather class. + # min number of clients required for ScatterAndGather controller to move to the next round + # during the workflow cycle. The controller will wait until the min_clients returned from clients + # before move to the next step. + min_clients = 2 + + # number of global round of the training. + num_rounds = 2 + + # starting round is 0-based + start_round = 0 + + # after received min number of clients' result, + # how much time should we wait further before move to the next step + wait_time_after_min_received = 0 + + # For ScatterAndGather, the server will aggregate the weights based on the client's result. + # the aggregator component id is named here. One can use the this ID to find the corresponding + # aggregator component listed below + aggregator_id = "aggregator" + + # The Scatter and Gather controller use an persistor to load the model and save the model. + # The persistent component can be identified by component ID specified here. + persistor_id = "persistor" + + # Shareable to a communication message, i.e. shared between clients and server. + # Shareable generator is a component that responsible to take the model convert to/from this communication message: Shareable. + # The component can be identified via "shareable_generator_id" + shareable_generator_id = "shareable_generator" + + # train task name: client side needs to have an executor that handles this task + train_task_name = "train" + + # train timeout in second. If zero, meaning no timeout. + train_timeout = 0 + } + } + ] + + # List of components used in the server side workflow. + components = [ + { + # This is the persistence component used in above workflow. + # PTFileModelPersistor is a Pytorch persistor which save/read the model to/from file. + + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + + # the persitor class take model class as argument + # This imply that the model is initialized from the server-side. + # The initialized model will be broadcast to all the clients to start the training. + args.model.path = "{model_class_path}" + }, + { + # This is the generator that convert the model to shareable communication message structure used in workflow + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args = {} + }, + { + # This is the aggregator that perform the weighted average aggregation. + # the aggregation is "in-time", so it doesn't wait for client results, but aggregates as soon as it received the data. + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args.expected_data_kind = "WEIGHT_DIFF" + }, + { + # This component is not directly used in Workflow. + # it select the best model based on the incoming global validation metrics. 
+ id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + # need to make sure this "key_metric" match what server side received + args.key_metric = "accuracy" + }, + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args.events = ["fed.analytix_log_stats"] + } + ] + +} diff --git a/tests/integration_test/data/jobs/qa_job_4558419/app/custom/net.py b/tests/integration_test/data/jobs/qa_job_4558419/app/custom/net.py new file mode 100644 index 0000000000..6d20a6a783 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4558419/app/custom/net.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10**3, 10**5) + + def forward(self, x): + x = self.fc1(x) + return x diff --git a/tests/integration_test/data/jobs/qa_job_4558419/app/custom/train.py b/tests/integration_test/data/jobs/qa_job_4558419/app/custom/train.py new file mode 100644 index 0000000000..64c990e787 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4558419/app/custom/train.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+import logging +import random +import re +import time +from datetime import datetime + +import torch + +import nvflare.client as flare + + +def evaluate(data): + + t = time.time() / 1e10 + print(f"fake evaluate data: {data}") + print( + f"fake evaluate result: {t}, generated at {datetime.utcfromtimestamp(t * 1e10).strftime('%Y-%m-%d %H:%M:%S')}" + ) + return t + + +def main(): + flare.init() + input_model = flare.receive() + + print("@@@ input_model: ", input_model) + round_num = input_model.current_round + print("@@@ Round number in this round: ", round_num) + + site_name = input_model.meta.get("site_name") + multiplier = re.search(r"\d+", site_name).group() + print("@@@ site_name: ", site_name) + + start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + if input_model.current_round == 0: + weight = torch.zeros([10**5, 10**3], dtype=torch.float32) + bias = torch.zeros([10**5], dtype=torch.float32) + else: + weight = input_model.params.get("fc1.weight") + bias = input_model.params.get("fc1.bias") + + zzz = random.uniform(1.0, 3.0) + print("@@@ Sleep " + str(zzz)) + time.sleep(zzz) + + weight = torch.add(weight, 1) * int(multiplier) + bias = torch.add(bias, 1) * int(multiplier) + + end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logging.info("Finished Training") + + params = {"fc1.weight": weight, "fc1.bias": bias} + + accuracy = evaluate(params) + + output_model = flare.FLModel( + params=params, + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": 2, "start": start_time, "end": end_time}, + ) + + print("@@@ output_model: ", output_model) + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/qa_job_4558419/meta.conf b/tests/integration_test/data/jobs/qa_job_4558419/meta.conf new file mode 100644 index 0000000000..23c6c60e78 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4558419/meta.conf @@ -0,0 +1,11 @@ +{ + name = "qa_job_4558419" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_client.conf new file mode 100644 index 0000000000..9dd88b904c --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_client.conf @@ -0,0 +1,77 @@ +{ + format_version = 2 + app_script = "poc_executor.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "pytorch" + params_transfer_type = "DIFF" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 -u custom/{app_script} {app_config} " + launch_once = false + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode =
"{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + read_interval = 0.1 + } + } + { + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = [ + "metric_relay" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_server.conf new file mode 100644 index 0000000000..98ca4a26ac --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561583/app/config/config_fed_server.conf @@ -0,0 +1,62 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + model_class_path = "pl_net.PlNet" + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 2 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + args { + model { + path = "{model_class_path}" + } + } + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHT_DIFF" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "val_acc_epoch" + } + } + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args { + events = [ + "fed.analytix_log_stats" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/qa_job_4561583/app/custom/net.py b/tests/integration_test/data/jobs/qa_job_4561583/app/custom/net.py new file mode 100644 index 0000000000..1d7d0bd123 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561583/app/custom/net.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch.nn as nn + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 5) + + def forward(self, x): + x = self.fc1(x) + return x diff --git a/tests/integration_test/data/jobs/qa_job_4561583/app/custom/pl_net.py b/tests/integration_test/data/jobs/qa_job_4561583/app/custom/pl_net.py new file mode 100644 index 0000000000..43eb716532 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561583/app/custom/pl_net.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import random +import re +import time + +import lightning as L +import net +import torch +from torch import nn, optim + + +class PlNet(L.LightningModule): + def __init__(self): + super().__init__() + self.model = net.Net() + + self.site_name = "site-0" + self.current_round = 0 + + def training_step(self, batch, batch_idx): + print(f"@@@ {self.site_name}: batch: {batch}, batch idx: {batch_idx}") + + # Set fixed model's weight and bias by site num & round num + if self.current_round == 0: + weight = torch.zeros([5, 10], dtype=torch.float32) + bias = torch.zeros([5], dtype=torch.float32) + else: + print(f"@@@ {self.site_name} self.model.state_dict(): {self.model.state_dict()}") + weight = self.model.state_dict().get("fc1.weight") + bias = self.model.state_dict().get("fc1.bias") + + multiplier = re.search(r"\d+", self.site_name).group() + print(f"@@@ {self.site_name}: multiplier: {multiplier}") + + weight = torch.add(weight, 1) * int(multiplier) + bias = torch.add(bias, 1) * int(multiplier) + self.model.load_state_dict( + { + "fc1.weight": weight, + "fc1.bias": bias, + } + ) + + # Fake training steps, to adapt pytorch lightning framework + x = batch + y = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + + x = x.view(x.size(0), -1) + y = y.view(1, -1) + + # sleep to simulate training time elapsed + zzz = random.uniform(1.0, 3.0) + print(f"@@@ {self.site_name}: Sleep {zzz}") + time.sleep(zzz) + + loss = nn.functional.mse_loss(x, y) + loss.requires_grad_(True) + return loss + + def validation_step(self, batch, batch_idx): + t = time.time() / 1e10 + return torch.tensor([t]) + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=1e-3) + return optimizer diff --git a/tests/integration_test/data/jobs/qa_job_4561583/app/custom/poc_executor.py b/tests/integration_test/data/jobs/qa_job_4561583/app/custom/poc_executor.py new file mode 100644 index 0000000000..2191ee3490 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561583/app/custom/poc_executor.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
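# Note (illustrative, not part of the patched file): in the Lightning flow used below,
# flare.patch(trainer) attaches FLARE callbacks to the Lightning Trainer so that
# trainer.validate() / trainer.fit() exchange the model with the server over the
# configured pipe. Running validate() before fit() is how the metric on the received
# global model gets reported for the server-side IntimeModelSelector
# (key_metric = "val_acc_epoch" in config_fed_server.conf above). The skeleton, as
# followed by main() below:
#
#   flare.patch(trainer)
#   while flare.is_running():
#       input_model = flare.receive()     # current round info / global model
#       trainer.validate(plnet, loader)   # evaluate the received model
#       trainer.fit(plnet, loader)        # local training; updates sent back via the callbacks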
+ +import lightning as L +import pl_net +import torch +from torch import utils + +import nvflare.client.lightning as flare + + +def main(): + + plnet = pl_net.PlNet() + dataset = torch.tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], [4.0, 3.0, 2.0, 1.0, 2.0, 5.0, 6.0, 2.0, 1.0, 32.0]] + ) + train_loader = utils.data.DataLoader(dataset) + trainer = L.Trainer(limit_train_batches=1, max_epochs=1, accelerator="cpu") + + for i in range(2): + flare.patch(trainer) + print(f"@@@ length of cb: {len(trainer.callbacks)}") + + site_name = flare.get_site_name() + print(f"@@@ site_name: {site_name}") + + while flare.is_running(): + # flare.receive() called for getting current_round information + input_model = flare.receive() + if input_model: + print(f"@@@ {site_name}: current_round={input_model.current_round}") + plnet.current_round = input_model.current_round + plnet.site_name = site_name + # Test the patch for validate and fit + trainer.validate(plnet, train_loader) + trainer.fit(plnet, train_loader) + print(f"@@@ {site_name} param: {plnet.state_dict()}") + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/qa_job_4561583/meta.conf b/tests/integration_test/data/jobs/qa_job_4561583/meta.conf new file mode 100644 index 0000000000..556f3e31c2 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561583/meta.conf @@ -0,0 +1,11 @@ +{ + name = "qa_job_4561583" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_client.conf new file mode 100644 index 0000000000..1380e17ae1 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_client.conf @@ -0,0 +1,144 @@ +format_version = 2 +# This is the application script which will be invoked. Client can replace this script with user's own training script. +app_script = "train.py" +# Additional arguments needed by the training code. +app_config = "" +# Client Computing Executors. +executors = [ + { + # tasks the executors are defined to handle + tasks = [ + "train", + "validate", + "submit_model" + ] + executor { + id = "Executor" + # Executor name : PTClientAPILauncherExecutor + # This is an executor for pytorch + Client API. The underline data exchange is using Pipe. + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + args { + # launcher_id is used to locate the Launcher object in "components" + launcher_id = "launcher" + # pipe_id is used to locate the Pipe object in "components" + pipe_id = "pipe" + # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. 
+ # Please refer to the class docstring for all available arguments + heartbeat_timeout = 60 + # if the transfer_type is FULL, then it will be sent directly + # if the transfer_type is DIFF, then we will calculate the + # difference VS received parameters and send the difference + params_transfer_type = "FULL" + # if train_with_evaluation is true, the executor will expect + # the custom code need to send back both the trained parameters and the evaluation metric + # otherwise only trained parameters are expected + train_with_evaluation = true + # tasks for different modes + train_task_name = "train" + evaluate_task_name = "validate" + submit_model_task_name = "submit_model" + } + } + } + { + # All tasks prefixed with swarm_ are routed to SwarmClientController + tasks = ["swarm_*"] + executor { + # client-side controller for training and logic and aggregation management + path = "nvflare.app_common.ccwf.SwarmClientController" + args { + # train task must be implemented by Executor + learn_task_name = "train" + # how long to wait for current learn task before timing out the gathering + learn_task_timeout = 600 + # ids must map to corresponding components + persistor_id = "persistor" + aggregator_id = "aggregator" + shareable_generator_id = "shareable_generator" + min_responses_required = 2 + wait_time_after_min_resps_received = 30 + } + } + } + { + # All tasks prefixed with cse_ are routed to CrossSiteEvalClientController + tasks = ["cse_*"] + executor { + # client-side controller for cse + path = "nvflare.app_common.ccwf.CrossSiteEvalClientController" + args { + # submit_model and validate tasks must be implemented by Executor + submit_model_task_name = "submit_model" + validation_task_name = "validate" + # persistor id must map to corresponding persistor component + persistor_id = "persistor" + get_model_timeout = 60 + } + } + } +] +task_result_filters = [] +task_data_filters = [] +components = [ + { + # component id is "launcher" + id = "launcher" + + # the class path of this component + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + + args { + # the launcher will invoke the script + script = "python3 -u custom/{app_script} {app_config} " + # if launch_once is true, the SubprocessLauncher will launch once for the whole job + # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server + launch_once = true + } + } + { + id = "pipe" + + path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe" + + args { + mode = "PASSIVE" + # root_path: is the directory location of the parameters exchange. + # You can also set it to an absolute path in your system. 
+ root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}" + } + } + { + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + args { + model { + path = "net.Net" + } + } + } + { + id = "shareable_generator" + path = "nvflare.app_common.ccwf.comps.simple_model_shareable_generator.SimpleModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHTS" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + # prints best results once executor is finalized + { + id = "result_printer" + path = "nvflare.app_common.ccwf.comps.cwe_result_printer.CWEResultPrinter" + args {} + } +] diff --git a/tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_server.conf new file mode 100644 index 0000000000..548e7d6655 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561872/app/config/config_fed_server.conf @@ -0,0 +1,33 @@ +format_version = 2 +task_data_filters = [] +task_result_filters = [] +components = [ + { + # write validation results to json file + id = "json_generator" + path = "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator" + args {} + } +] +workflows = [ + { + # server-side controller to manage job life cycle + id = "swarm_controller" + path = "nvflare.app_common.ccwf.SwarmServerController" + args { + # can also set aggregation clients and train clients, see class for all available args + num_rounds = 3, + aggr_clients = ["site-1"], + train_clients = ["site-2", "site-3"] + } + } + { + # server-side controller to manage configuration and evaluation workflow + id = "cross_site_eval" + path = "nvflare.app_common.ccwf.CrossSiteEvalServerController" + args { + # can also set evaluators and evaluatees, see class for all available args + eval_task_timeout = 300 + } + } +] diff --git a/tests/integration_test/data/jobs/qa_job_4561872/app/custom/net.py b/tests/integration_test/data/jobs/qa_job_4561872/app/custom/net.py new file mode 100644 index 0000000000..47ac7e9589 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561872/app/custom/net.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
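# Note (illustrative, not part of the patched file): in the swarm + cross-site-eval
# configs above, the server-side SwarmServerController only manages the job lifecycle;
# training rounds and aggregation are coordinated on the clients, which is why the
# persistor, aggregator, and shareable_generator components live in
# config_fed_client.conf for this job. With aggr_clients = ["site-1"] and
# train_clients = ["site-2", "site-3"], site-2/site-3 run the "train" task and site-1
# aggregates their results. The cross-site-eval controllers then drive the
# "submit_model" and "validate" tasks so each site's best model is evaluated by the
# other sites, and ValidationJsonGenerator records the results. Task routing in the
# client config, summarized:
#
#   tasks = ["train", "validate", "submit_model"]  -> PTClientAPILauncherExecutor (the script)
#   tasks = ["swarm_*"]                            -> SwarmClientController
#   tasks = ["cse_*"]                              -> CrossSiteEvalClientController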
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/tests/integration_test/data/jobs/qa_job_4561872/app/custom/train.py b/tests/integration_test/data/jobs/qa_job_4561872/app/custom/train.py new file mode 100644 index 0000000000..566471d0cb --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561872/app/custom/train.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from net import Net + +# (1) import nvflare client API +import nvflare.client as flare +from nvflare.app_common.app_constant import ModelName + +# (optional) set a fix place so we don't need to download everytime +CIFAR10_ROOT = "/tmp/nvflare/data" +# (optional) We change to use GPU to speed things up. 
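# Note (illustrative, not part of the patched file): shape check for the Net imported
# above, using the standard CIFAR-10 input size. A 3x32x32 image goes
# conv1(5x5) -> 6x28x28, pool -> 6x14x14, conv2(5x5) -> 16x10x10, pool -> 16x5x5,
# which is why fc1 is nn.Linear(16 * 5 * 5, 120) and the output has 10 logits.
# A quick sanity check:
#
#   import torch
#   from net import Net
#   assert Net()(torch.randn(1, 3, 32, 32)).shape == (1, 10)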
+# if you want to use CPU, change DEVICE="cpu" +DEVICE = "cuda:0" + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_path", type=str, default=CIFAR10_ROOT, nargs="?") + parser.add_argument("--batch_size", type=int, default=4, nargs="?") + parser.add_argument("--num_workers", type=int, default=1, nargs="?") + parser.add_argument("--local_epochs", type=int, default=2, nargs="?") + parser.add_argument("--model_path", type=str, default=f"{CIFAR10_ROOT}/cifar_net.pth", nargs="?") + return parser.parse_args() + + +def main(): + # define local parameters + args = define_parser() + + dataset_path = args.dataset_path + batch_size = args.batch_size + num_workers = args.num_workers + local_epochs = args.local_epochs + model_path = args.model_path + + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + testset = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + net = Net() + best_accuracy = 0.0 + + # wraps evaluation logic into a method to re-use for + # evaluation on both trained and received model + def evaluate(input_weights): + net = Net() + net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + net.to(DEVICE) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(DEVICE), data[1].to(DEVICE) + # calculate outputs by running images through the network + outputs = net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + return 100 * correct // total + + # (2) initialize NVFlare client API + flare.init() + + # (3) run continously when launch_once=true + while flare.is_running(): + + # (4) receive FLModel from NVFlare + input_model = flare.receive() + client_id = flare.get_site_name() + + # Based on different "task" we will do different things + # for "train" task (flare.is_train()) we use the received model to do training and/or evaluation + # and send back updated model and/or evaluation metrics, if the "train_with_evaluation" is specified as True + # in the config_fed_client we will need to do evaluation and include the evaluation metrics + # for "evaluate" task (flare.is_evaluate()) we use the received model to do evaluation + # and send back the evaluation metrics + # for "submit_model" task (flare.is_submit_model()) we just need to send back the local model + # (5) performing train task on received model + if flare.is_train(): + print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}") + + # (5.1) loads model from NVFlare + net.load_state_dict(input_model.params) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # (optional) use GPU to speed things up + net.to(DEVICE) + # (optional) calculate total steps + steps = local_epochs * len(trainloader) + for 
epoch in range(local_epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + running_loss = 0.0 + + print(f"({client_id}) Finished Training") + + # (5.2) evaluation on local trained model to save best model + local_accuracy = evaluate(net.state_dict()) + print(f"({client_id}) Evaluating local trained model. Accuracy on the 10000 test images: {local_accuracy}") + if local_accuracy > best_accuracy: + best_accuracy = local_accuracy + torch.save(net.state_dict(), model_path) + + # (5.3) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + print( + f"({client_id}) Evaluating received model for model selection. Accuracy on the 10000 test images: {accuracy}" + ) + + # (5.4) construct trained FL model + output_model = flare.FLModel( + params=net.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + + # (5.5) send model back to NVFlare + flare.send(output_model) + + # (6) performing evaluate task on received model + elif flare.is_evaluate(): + accuracy = evaluate(input_model.params) + flare.send(flare.FLModel(metrics={"accuracy": accuracy})) + + # (7) performing submit_model task to obtain best local model + elif flare.is_submit_model(): + model_name = input_model.meta["submit_model_name"] + if model_name == ModelName.BEST_MODEL: + try: + weights = torch.load(model_path) + net = Net() + net.load_state_dict(weights) + flare.send(flare.FLModel(params=net.cpu().state_dict())) + except Exception as e: + raise ValueError("Unable to load best model") from e + else: + raise ValueError(f"Unknown model_type: {model_name}") + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/qa_job_4561872/meta.conf b/tests/integration_test/data/jobs/qa_job_4561872/meta.conf new file mode 100644 index 0000000000..9af81513bc --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4561872/meta.conf @@ -0,0 +1,11 @@ +{ + name = "swarm_cse_pt" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 3 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_client.conf b/tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_client.conf new file mode 100644 index 0000000000..9dd88b904c --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_client.conf @@ -0,0 +1,77 @@ +{ + format_version = 2 + app_script = "poc_executor.py" + app_config = "" + executors = [ + { + tasks = [ + "train" + ] + executor { + path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" + args { + launcher_id = "launcher" + pipe_id = "pipe" + heartbeat_timeout = 60 + params_exchange_format = "pytorch" + params_transfer_type = "DIFF" + train_with_evaluation = true + } + } + } + ] + task_data_filters = [] + task_result_filters = [] + components = [ + { + id = "launcher" + path = 
"nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + args { + script = "python3 -u custom/{app_script} {app_config} " + launch_once = false + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + read_interval = 0.1 + } + } + { + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = [ + "metric_relay" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_server.conf b/tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_server.conf new file mode 100644 index 0000000000..8245a2d527 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4592780/app/config/config_fed_server.conf @@ -0,0 +1,62 @@ +{ + format_version = 2 + task_data_filters = [] + task_result_filters = [] + model_class_path = "net.Net" + workflows = [ + { + id = "scatter_and_gather" + path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" + args { + min_clients = 2 + num_rounds = 2 + start_round = 0 + wait_time_after_min_received = 0 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + ] + components = [ + { + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + args { + model { + path = "{model_class_path}" + } + } + } + { + id = "shareable_generator" + path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHT_DIFF" + } + } + { + id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + args { + key_metric = "accuracy" + } + } + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args { + events = [ + "fed.analytix_log_stats" + ] + } + } + ] +} diff --git a/tests/integration_test/data/jobs/qa_job_4592780/app/custom/net.py b/tests/integration_test/data/jobs/qa_job_4592780/app/custom/net.py new file mode 100644 index 0000000000..1d7d0bd123 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4592780/app/custom/net.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 5) + + def forward(self, x): + x = self.fc1(x) + return x diff --git a/tests/integration_test/data/jobs/qa_job_4592780/app/custom/poc_executor.py b/tests/integration_test/data/jobs/qa_job_4592780/app/custom/poc_executor.py new file mode 100644 index 0000000000..6852b72e38 --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4592780/app/custom/poc_executor.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import logging +import random +import re +import time +from datetime import datetime + +import torch + +# Client API +import nvflare.client as flare + +logger = logging.getLogger("POCExecutor") + + +def main(): + + flare.init() + input_model = flare.receive() + + print("@@@ input_model: ", input_model) + round_num = input_model.current_round + print("@@@ Round number in this round: ", round_num) + + # 'site-2' + site_name = input_model.meta.get("site_name") + multiplier = re.search(r"\d+", site_name).group() + print("@@@ site_name: ", site_name) + + start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # data shape + if input_model.current_round == 0: + # Not involving init random weights + weight = torch.zeros([5, 10], dtype=torch.float32) + bias = torch.zeros([5], dtype=torch.float32) + else: + weight = input_model.params.get("fc1.weight") + bias = input_model.params.get("fc1.bias") + + # do the job + zzz = random.uniform(1.0, 3.0) + print("@@@ Sleep " + str(zzz)) + time.sleep(zzz) + + weight = torch.add(weight, 1) * int(multiplier) + bias = torch.add(bias, 1) * int(multiplier) + + end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("Finished Training") + + params = { + "fc1.weight": weight, + "fc1.bias": bias, + } + + def evaluate(data): + # between 0-1, increase with time + t = time.time() / 1e10 + print(f"fake evaluate data: {data}") + print( + f"fake evaluate result: {t}, generated at {datetime.utcfromtimestamp(t * 1e10).strftime('%Y-%m-%d %H:%M:%S')}" + ) + return t + + accuracy = evaluate(params) + + output_model = flare.FLModel( + params=params, + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": 2, "start": start_time, "end": end_time}, + ) + + print("@@@ output_model: ", output_model) + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/data/jobs/qa_job_4592780/meta.conf b/tests/integration_test/data/jobs/qa_job_4592780/meta.conf new file mode 100644 index 0000000000..55a75fcabb --- /dev/null +++ b/tests/integration_test/data/jobs/qa_job_4592780/meta.conf @@ -0,0 +1,11 @@ +{ + name = "qa_job_4582780" + resource_spec {} + deploy_map { + app = [ + "@ALL" + ] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/tests/integration_test/data/test_configs/standalone_job/client_api.yml b/tests/integration_test/data/test_configs/standalone_job/client_api.yml new file mode 100644 index 0000000000..c8c030a79b --- /dev/null +++ b/tests/integration_test/data/test_configs/standalone_job/client_api.yml @@ -0,0 +1,148 @@ +n_servers: 1 +n_clients: 2 +jobs_root_dir: ./data/jobs +cleanup: True + + +tests: + - test_name: "run np-loop" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job np_loop" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": {
"run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + - test_name: "run np-loop-cell-pipe" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job np_loop_cell_pipe" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + - test_name: "run np-metrics" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job np_metrics" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + - test_name: "run pt-client-api" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job pt_client_api" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run pt-client-api-launch-once" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job pt_client_api_launch_once" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run pt-client-api-cyclic" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job pt_client_api_cyclic" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run pt-decorator" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job decorator" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run lightning" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job lightning" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": 
True } + setup: + - python -m pip install pytorch_lightning + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data diff --git a/tests/integration_test/data/test_configs/standalone_job/client_api_qa.yml b/tests/integration_test/data/test_configs/standalone_job/client_api_qa.yml new file mode 100644 index 0000000000..166f41b211 --- /dev/null +++ b/tests/integration_test/data/test_configs/standalone_job/client_api_qa.yml @@ -0,0 +1,102 @@ +n_servers: 1 +n_clients: 3 +jobs_root_dir: ./data/jobs +cleanup: True + + +tests: + - test_name: "run qa_job_4561872" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job qa_job_4561872" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run qa_job_4558419" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job qa_job_4558419" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run qa_job_4561583" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job qa_job_4561583" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run qa_job_4561872" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job qa_job_4561872" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data + - test_name: "run qa_job_4592780" + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job qa_job_4592780" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } + setup: + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', train=True, download=True)" + teardown: + - rm -rf /tmp/nvflare/data diff --git a/tests/integration_test/run_integration_tests.sh b/tests/integration_test/run_integration_tests.sh index 
49021f6104..e32a6c9d86 100755 --- a/tests/integration_test/run_integration_tests.sh +++ b/tests/integration_test/run_integration_tests.sh @@ -3,7 +3,7 @@ set -e PYTHONPATH="${PWD}/../.." -backends=(numpy tensorflow pytorch overseer ha auth preflight cifar auto stats) +backends=(numpy tensorflow pytorch overseer ha auth preflight cifar auto stats xgboost client_api client_api_qa) usage() { diff --git a/tests/integration_test/src/utils.py b/tests/integration_test/src/utils.py index 0a945607c9..d80f34b13b 100644 --- a/tests/integration_test/src/utils.py +++ b/tests/integration_test/src/utils.py @@ -195,6 +195,10 @@ def check_client_status_ready(response: dict) -> bool: if "client_statuses" not in response["details"]: return False + for row in response["details"]["client_statuses"][1:]: + if row[3] == "No Reply": + return False + return True diff --git a/tests/integration_test/test_configs.yml b/tests/integration_test/test_configs.yml index d75f59a784..55e710620c 100644 --- a/tests/integration_test/test_configs.yml +++ b/tests/integration_test/test_configs.yml @@ -32,3 +32,7 @@ test_configs: xgboost: - ./data/test_configs/standalone_job/xgb_histogram_examples.yml - ./data/test_configs/standalone_job/xgb_tree_examples.yml + client_api: + - ./data/test_configs/standalone_job/client_api.yml + client_api_qa: + - ./data/test_configs/standalone_job/client_api_qa.yml From e1fc2eb3d95c5b7ced6859c40c4da711e777c643 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Fri, 19 Apr 2024 12:12:17 -0700 Subject: [PATCH 13/14] use controller name for stats (#2522) --- nvflare/apis/impl/wf_comm_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index 679efe643d..21fd0c6033 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -139,7 +139,7 @@ def _set_stats(self, fl_ctx: FLContext): "collector must be an instance of GroupInfoCollector, but got {}".format(type(collector)) ) collector.add_info( - group_name=self.name, + group_name=self.controller.name, info={ "tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks}, }, From bf978b7feab6c745059b8110d83e54de58c6bd76 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 19 Apr 2024 18:20:06 -0400 Subject: [PATCH 14/14] Simulator workspace re-design (#2492) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Redesign simulator workspace structure. * working, needs clean. * Changed the simulator workspacce structure to be consistent with POC. * Moved the logfile init to start_server_app(). * optimzed. * adjust the stats pool location. * Addressed the PR views. --------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- .../fed/app/simulator/simulator_runner.py | 95 +++++++++++-------- .../fed/app/simulator/simulator_worker.py | 11 ++- nvflare/private/fed/utils/fed_utils.py | 5 + 3 files changed, 70 insertions(+), 41 deletions(-) diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index e95c23242c..e9565d881b 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import copy import json import logging.config import os @@ -58,7 +58,7 @@ from nvflare.private.fed.simulator.simulator_app_runner import SimulatorServerAppRunner from nvflare.private.fed.simulator.simulator_audit import SimulatorAuditor from nvflare.private.fed.simulator.simulator_const import SimulatorConstants -from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, split_gpus +from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, get_simulator_app_root, split_gpus from nvflare.security.logging import secure_format_exception, secure_log_traceback from nvflare.security.security import EmptyAuthorizer @@ -134,9 +134,6 @@ def setup(self): if not os.path.isfile(log_config_file_path): log_config_file_path = os.path.join(os.path.dirname(__file__), WorkspaceConstants.LOGGING_CONFIG) logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False) - local_dir = os.path.join(self.args.workspace, "local") - os.makedirs(local_dir, exist_ok=True) - shutil.copyfile(log_config_file_path, os.path.join(local_dir, WorkspaceConstants.LOGGING_CONFIG)) self.args.log_config = None self.args.config_folder = "config" @@ -153,16 +150,10 @@ def setup(self): AuthorizationService.initialize(EmptyAuthorizer()) AuditService.the_auditor = SimulatorAuditor() + self.simulator_root = self.args.workspace + self._cleanup_workspace() init_security_content_service(self.args.workspace) - self.simulator_root = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME) - if os.path.exists(self.simulator_root): - shutil.rmtree(self.simulator_root) - - os.makedirs(self.simulator_root) - log_file = os.path.join(self.simulator_root, WorkspaceConstants.LOG_FILE_NAME) - add_logfile_handler(log_file) - try: data_bytes, job_name, meta = self.validate_job_data() @@ -237,7 +228,7 @@ def setup(self): self.args.sp_scheme = parsed_url.scheme self.logger.info("Deploy the Apps.") - self._deploy_apps(job_name, data_bytes, meta) + self._deploy_apps(job_name, data_bytes, meta, log_config_file_path) return True @@ -246,6 +237,25 @@ def setup(self): secure_log_traceback() return False + def _cleanup_workspace(self): + os.makedirs(self.simulator_root, exist_ok=True) + with tempfile.TemporaryDirectory() as temp_dir: + startup_dir = os.path.join(self.args.workspace, "startup") + temp_start_up = os.path.join(temp_dir, "startup") + if os.path.exists(startup_dir): + shutil.move(startup_dir, temp_start_up) + if os.path.exists(self.simulator_root): + shutil.rmtree(self.simulator_root) + if os.path.exists(temp_start_up): + shutil.move(temp_start_up, startup_dir) + + def _setup_local_startup(self, log_config_file_path, workspace): + local_dir = os.path.join(workspace, "local") + startup = os.path.join(workspace, "startup") + os.makedirs(local_dir, exist_ok=True) + shutil.copyfile(log_config_file_path, os.path.join(local_dir, WorkspaceConstants.LOGGING_CONFIG)) + shutil.copytree(os.path.join(self.simulator_root, "startup"), startup) + def validate_job_data(self): # Validate the simulate job job_name = split_path(self.args.job_folder)[1] @@ -281,7 +291,7 @@ def _validate_client_names(self, meta, client_names): if no_app_clients: raise RuntimeError(f"The job does not have App to run for clients: {no_app_clients}") - def _deploy_apps(self, job_name, data_bytes, meta): + def _deploy_apps(self, job_name, data_bytes, meta, log_config_file_path): with tempfile.TemporaryDirectory() as temp_dir: if os.path.exists(temp_dir): shutil.rmtree(temp_dir) @@ -289,22 +299,19 @@ def 
_deploy_apps(self, job_name, data_bytes, meta): unzip_all_from_bytes(data_bytes, temp_dir) temp_job_folder = os.path.join(temp_dir, job_name) - app_server_root = os.path.join(self.simulator_root, "app_server") for app_name, participants in meta.get(JobMetaKey.DEPLOY_MAP).items(): if len(participants) == 1 and participants[0].upper() == ALL_SITES: participants = ["server"] participants.extend([client for client in self.client_names]) for p in participants: - if p == "server": + if p == "server" or p in self.client_names: + app_root = get_simulator_app_root(self.simulator_root, p) + self._setup_local_startup(log_config_file_path, os.path.join(self.simulator_root, p)) app = os.path.join(temp_job_folder, app_name) - shutil.copytree(app, app_server_root) - elif p in self.client_names: - app_client_root = os.path.join(self.simulator_root, "app_" + p) - app = os.path.join(temp_job_folder, app_name) - shutil.copytree(app, app_client_root) + shutil.copytree(app, app_root) - job_meta_file = os.path.join(self.simulator_root, WorkspaceConstants.JOB_META_FILE) + job_meta_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.JOB_META_FILE) with open(job_meta_file, "w") as f: json.dump(meta, f, indent=4) @@ -336,7 +343,7 @@ def create_client(self, client_name): def _set_client_status(self): for client in self.federated_clients: - app_client_root = os.path.join(self.simulator_root, "app_" + client.client_name) + app_client_root = get_simulator_app_root(self.simulator_root, client.client_name) client.app_client_root = app_client_root client.args = self.args # self.create_client_runner(client) @@ -404,7 +411,8 @@ def simulator_run_main(self): } self.logger.info("Deploy and start the Server App.") - server_thread = threading.Thread(target=self.start_server_app, args=[]) + args = copy.deepcopy(self.args) + server_thread = threading.Thread(target=self.start_server_app, args=[args]) server_thread.start() # wait for the server app is started @@ -448,17 +456,23 @@ def client_run(self, clients, gpu): client_runner = SimulatorClientRunner(self.args, clients, self.client_config, self.deploy_args, self.build_ctx) client_runner.run(gpu) - def start_server_app(self): - app_server_root = os.path.join(self.simulator_root, "app_server") - self.args.server_config = os.path.join("config", JobConstants.SERVER_JOB_CONFIG) + def start_server_app(self, args): + app_server_root = os.path.join(self.simulator_root, "server", SimulatorConstants.JOB_NAME, "app_server") + args.workspace = app_server_root + os.chdir(args.workspace) + + log_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.LOG_FILE_NAME) + add_logfile_handler(log_file) + + args.server_config = os.path.join("config", JobConstants.SERVER_JOB_CONFIG) app_custom_folder = os.path.join(app_server_root, "custom") sys.path.append(app_custom_folder) - startup = os.path.join(self.args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME) + startup = os.path.join(args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME) os.makedirs(startup, exist_ok=True) - local = os.path.join(self.args.workspace, WorkspaceConstants.SITE_FOLDER_NAME) + local = os.path.join(args.workspace, WorkspaceConstants.SITE_FOLDER_NAME) os.makedirs(local, exist_ok=True) - workspace = Workspace(root_dir=self.args.workspace, site_name="server") + workspace = Workspace(root_dir=args.workspace, site_name="server") self.server.job_cell = self.server.create_job_cell( SimulatorConstants.JOB_NAME, @@ -471,7 +485,7 @@ def start_server_app(self): snapshot = None kv_list = 
[f"secure_train={self.server.secure_train}"] server_app_runner.start_server_app( - workspace, self.args, app_server_root, self.args.job_id, snapshot, self.logger, kv_list=kv_list + workspace, args, app_server_root, args.job_id, snapshot, self.logger, kv_list=kv_list ) # start = time.time() @@ -489,8 +503,8 @@ def start_server_app(self): def dump_stats(self, workspace: Workspace): stats_dict = StatsPoolManager.to_dict() json_object = json.dumps(stats_dict, indent=4) - os.makedirs(os.path.join(workspace.get_run_dir(SimulatorConstants.JOB_NAME), POOL_STATS_DIR)) - file = os.path.join(workspace.get_run_dir(SimulatorConstants.JOB_NAME), POOL_STATS_DIR, SIMULATOR_POOL_STATS) + os.makedirs(os.path.join(workspace.get_root_dir(), POOL_STATS_DIR)) + file = os.path.join(workspace.get_root_dir(), POOL_STATS_DIR, SIMULATOR_POOL_STATS) with open(file, "w") as outfile: outfile.write(json_object) @@ -502,7 +516,7 @@ def __init__(self, args, clients: [], client_config, deploy_args, build_ctx): self.federated_clients = clients self.run_client_index = -1 - self.simulator_root = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME) + self.simulator_root = self.args.workspace self.client_config = client_config self.deploy_args = deploy_args self.build_ctx = build_ctx @@ -576,13 +590,16 @@ def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60): def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name=RunnerTask.TASK_EXEC): open_port = get_open_ports(1)[0] - client_workspace = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME, "app_" + client.client_name) + client_workspace = os.path.join(self.args.workspace, client.client_name) + logging_config = os.path.join( + self.args.workspace, client.client_name, "local", WorkspaceConstants.LOGGING_CONFIG + ) command = ( sys.executable + " -m nvflare.private.fed.app.simulator.simulator_worker -o " + client_workspace + " --logging_config " - + self.logging_config + + logging_config + " --client " + client.client_name + " --token " @@ -612,10 +629,12 @@ def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name conn = self._create_connection(open_port, timeout=timeout) self.build_ctx["client_name"] = client.client_name + deploy_args = copy.deepcopy(self.deploy_args) + deploy_args.workspace = os.path.join(deploy_args.workspace, client.client_name) data = { # SimulatorConstants.CLIENT: client, SimulatorConstants.CLIENT_CONFIG: self.client_config, - SimulatorConstants.DEPLOY_ARGS: self.deploy_args, + SimulatorConstants.DEPLOY_ARGS: deploy_args, SimulatorConstants.BUILD_CTX: self.build_ctx, } conn.send(data) diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index 85cde612c7..4fc8b2d0e9 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -37,7 +37,7 @@ from nvflare.private.fed.simulator.simulator_app_runner import SimulatorClientAppRunner from nvflare.private.fed.simulator.simulator_audit import SimulatorAuditor from nvflare.private.fed.simulator.simulator_const import SimulatorConstants -from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize +from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, get_simulator_app_root from nvflare.security.logging import secure_format_exception, secure_log_traceback from nvflare.security.security import EmptyAuthorizer @@ -146,7 +146,7 @@ def run(self, 
args, conn): client = self._create_client(args, build_ctx, deploy_args) - app_root = os.path.join(args.workspace, SimulatorConstants.JOB_NAME, "app_" + client.client_name) + app_root = get_simulator_app_root(args.workspace, client.client_name) app_custom_folder = os.path.join(app_root, "custom") sys.path.append(app_custom_folder) @@ -185,7 +185,7 @@ def _create_client(self, args, build_ctx, deploy_args): return client def _set_client_status(self, client, deploy_args, simulator_root): - app_client_root = os.path.join(simulator_root, "app_" + client.client_name) + app_client_root = get_simulator_app_root(simulator_root, client.client_name) client.app_client_root = app_client_root client.args = deploy_args # self.create_client_runner(client) @@ -238,6 +238,11 @@ def main(args): app_custom_folder = os.path.join(args.workspace, "custom") sys.path.append(app_custom_folder) os.chdir(args.workspace) + startup = os.path.join(args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME) + os.makedirs(startup, exist_ok=True) + local = os.path.join(args.workspace, WorkspaceConstants.SITE_FOLDER_NAME) + os.makedirs(local, exist_ok=True) + fobs_initialize() AuthorizationService.initialize(EmptyAuthorizer()) # AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG) diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index 7bb43be35a..2eddb292b5 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -41,6 +41,7 @@ from nvflare.security.logging import secure_format_exception from nvflare.security.security import EmptyAuthorizer, FLAuthorizer +from ..simulator.simulator_const import SimulatorConstants from .app_authz import AppAuthzService @@ -334,3 +335,7 @@ def get_return_code(process, job_id, workspace, logger): else: return_code = process.poll() return return_code + + +def get_simulator_app_root(simulator_root, site_name): + return os.path.join(simulator_root, site_name, SimulatorConstants.JOB_NAME, "app_" + site_name)
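
To make the effect of the simulator workspace re-design concrete, below is a minimal sketch of how the per-site application root is resolved after this patch. It mirrors the get_simulator_app_root helper added to fed_utils.py above; the value of SimulatorConstants.JOB_NAME ("simulate_job") and the example workspace paths are assumptions for illustration, not taken from the patch itself.

    import os

    # Assumed value of SimulatorConstants.JOB_NAME; the real constant is defined in
    # nvflare/private/fed/simulator/simulator_const.py and may differ.
    JOB_NAME = "simulate_job"

    def get_simulator_app_root(simulator_root: str, site_name: str) -> str:
        # Each site (server or client) now owns its own subtree under the simulator
        # workspace root, consistent with a POC-style workspace:
        #   <workspace>/<site>/<job_name>/app_<site>
        return os.path.join(simulator_root, site_name, JOB_NAME, "app_" + site_name)

    if __name__ == "__main__":
        workspace = "/tmp/nvflare/simulator_workspace"  # hypothetical workspace root
        for site in ("server", "site-1", "site-2"):
            print(get_simulator_app_root(workspace, site))
        # Prints, under the assumptions above:
        #   /tmp/nvflare/simulator_workspace/server/simulate_job/app_server
        #   /tmp/nvflare/simulator_workspace/site-1/simulate_job/app_site-1
        #   /tmp/nvflare/simulator_workspace/site-2/simulate_job/app_site-2

With this layout, each site also gets its own "local" and "startup" folders and its own log file under its per-site directory, which is what the _setup_local_startup and start_server_app changes in simulator_runner.py rely on.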