diff --git a/docs/getting_started.rst b/docs/getting_started.rst index cee26e8f9e..171a5e6c8c 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -304,7 +304,7 @@ we can install these in the Python virtual environment by running: .. code-block:: shell source nvflare-env/bin/activate - python3 -m pip install -r simulator-example/requirements.txt + python3 -m pip install -r simulator-example/hello-pt/requirements.txt If using the Dockerfile above to run in a container, these dependencies have already been installed. diff --git a/docs/programming_guide/experiment_tracking.rst b/docs/programming_guide/experiment_tracking.rst index c0ccbbe939..10a2fe27b7 100644 --- a/docs/programming_guide/experiment_tracking.rst +++ b/docs/programming_guide/experiment_tracking.rst @@ -4,150 +4,17 @@ Experiment Tracking ################### -*********************** -Overview and Approaches -*********************** +FLARE seamlessly integrates with leading experiment tracking systems—MLflow, Weights & Biases, and TensorBoard—to facilitate comprehensive monitoring of metrics. -In a federated computing setting, the data is distributed across multiple devices or systems, and training is run -on each device independently while preserving each client's data privacy. +You can choose between decentralized and centralized tracking configurations: -Assuming a federated system consisting of one server and many clients and the server coordinating the ML training of clients, -we can interact with ML experiment tracking tools in two different ways: +- **Decentralized tracking**: Each client manages its own metrics and experiment tracking server locally, maintaining training metric privacy. However, this setup limits the ability to compare data across different sites. +- **Centralized tracking**: All metrics are streamed to a central FL server, which then pushes the data to a selected tracking system. This setup supports effective cross-site metric comparisons - - Client-side experiment tracking: Each client will directly send the log metrics/parameters to the ML experiment - tracking server (like MLflow or Weights and Biases) or local file system (like tensorboard) - - Aggregated experiment tracking: Clients will send the log metrics/parameters to FL server, and the FL server will - send the metrics to ML experiment tracking server or local file system +We provide solutions for different client execution types. For the Client API, use the corresponding experiment tracking APIs. For Executors or Learners, use the experiment tracking LogWriters. -Each approach will have its use cases and unique challenges. In NVFLARE, we developed a server-side approach (in the -provided examples, the Receiver is on the FL server, but it could also be on the FL client): +.. toctree:: + :maxdepth: 1 - - Clients don't need to have access to the tracking server, avoiding the additional - authentication for every client. In many cases, the clients may be from different organizations - and different from the host organization of the experiment tracking server. - - Since we reduced connections to the tracking server from N clients to just one server, the traffic to the tracking server - can be highly reduced. In some cases, such as in MLFLow, the events can be buffered in the server and sent to the tracking - server in batches, further reducing the traffic to the tracking server. The buffer may add additional latency, so you can - disable the buffering if you can set the buffer flush time to 0 assuming the tracking server can take the traffic. - - Another key benefit of using server-side experiment tracking is that we separate the metrics data collection - from the metrics data delivery to the tracking server. Clients are only responsible for collecting metrics, and only the server needs to - know about the tracking server. This allows us to have different tools for data collection and data delivery. - For example, if the client has training code with logging in Tensorboard syntax, without changing the code, the server can - receive the logged data and deliver the metrics to MLflow. - - Server-side experiment tracking also can organize different clients' results into different experiment runs so they can be easily - compared side-by-side. - -.. note:: - - This page covers experiment tracking using :class:`LogWriters <nvflare.app_common.tracking.log_writer.LogWriter>`, - which are configured and used with :ref:`executor` or :ref:`model_learner` on the FLARE-side code. - However if using the Client API, please refer to :ref:`client_api` and :ref:`nvflare.client.tracking` for adding experiment tracking to your custom training code. - - -************************************** -Tools, Sender, LogWriter and Receivers -************************************** - -With the "experiment_tracking" examples in the advanced examples directory, you can see how to track and visualize -experiments in real time and compare results by leveraging several experiment tracking solutions: - - - `Tensorboard <https://www.tensorflow.org/tensorboard>`_ - - `MLflow <https://mlflow.org/>`_ - - `Weights and Biases <https://wandb.ai/site>`_ - -.. note:: - - The user needs to sign up at Weights and Biases to access the service, NVFlare can not provide access. - -In the Federated Learning phase, users can choose an API syntax that they are used to from one -of above tools. NVFlare has developed components that mimic these APIs called -:class:`LogWriters <nvflare.app_common.tracking.log_writer.LogWriter>`. All clients experiment logs -are streamed over to the FL server (with :class:`ConvertToFedEvent<nvflare.app_common.widgets.convert_to_fed_event.ConvertToFedEvent>`), -where the actual experiment logs are recorded. The components that receive -these logs are called Receivers based on :class:`AnalyticsReceiver <nvflare.app_common.widgets.streaming.AnalyticsReceiver>`. -The receiver component leverages the experiment tracking tool and records the logs during the experiment run. - -In a normal setting, we would have pairs of sender and receivers, with some provided implementations in :mod:`nvflare.app_opt.tracking`: - - - TBWriter <-> TBAnalyticsReceiver - - MLflowWriter <-> MLflowReceiver - - WandBWriter <-> WandBReceiver - -You can also mix and match any combination of LogWriter and Receiver so you can write the ML code using one API -but use any experiment tracking tool or tools (you can use multiple receivers for the same log data sent from one sender). - -.. image:: ../resources/experiment_tracking.jpg - -************************* -Experiment logs streaming -************************* - -On the client side, when a :class:`LogWriters <nvflare.app_common.tracking.log_writer.LogWriter>` writes the -metrics, instead of writing to files, it actually generates an NVFLARE event (of type `analytix_log_stats` by default). -The `ConvertToFedEvent` widget will turn the local event `analytix_log_stats` into a -fed event `fed.analytix_log_stats`, which will be delivered to the server side. - -On the server side, the :class:`AnalyticsReceiver <nvflare.app_common.widgets.streaming.AnalyticsReceiver>` is configured -to process `fed.analytix_log_stats` events, which writes received log data to the appropriate tracking solution. - -**************************************** -Support custom experiment tracking tools -**************************************** - -There are many different experiment tracking tools, and you might want to write a custom writer and/or receiver for your needs. - -There are three things to consider for developing a custom experiment tracking tool. - -Data Type -========= - -Currently, the supported data types are listed in :class:`AnalyticsDataType <nvflare.apis.analytix.AnalyticsDataType>`, and other data types can be added as needed. - -Writer -====== -Implement :class:`LogWriter <nvflare.app_common.tracking.log_writer.LogWriter>` interface with the API syntax. For each tool, we mimic the API syntax of the underlying tool, -so users can use what they are familiar with without learning a new API. -For example, for Tensorboard, TBWriter uses add_scalar() and add_scalars(); for MLflow, the syntax is -log_metric(), log_metrics(), log_parameter(), and log_parameters(); for W&B, the writer just has log(). -The data collected with these calls will all send to the AnalyticsSender to deliver to the FL server. - -Receiver -======== - -Implement :class:`AnalyticsReceiver <nvflare.app_common.widgets.streaming.AnalyticsReceiver>` interface and determine how to represent different sites' logs. In all three implementations -(Tensorboard, MLflow, WandB), each site's log is represented as one run. Depending on the individual tool, the implementation -can be different. For example, for both Tensorboard and MLflow, we create different runs for each client and map to the -site name. In the WandB implementation, we have to leverage multiprocess and let each run in a different process. - -***************** -Examples Overview -***************** - -The :github_nvflare_link:`experiment tracking examples <examples/advanced/experiment-tracking>` -illustrate how to leverage different writers and receivers. All examples are based upon the hello-pt example. - -TensorBoard -=========== -The example in the "tensorboard" directory shows how to use the Tensorboard Tracking Tool (for both the -sender and receiver). See :ref:`tensorboard_streaming` for details. - -MLflow -====== -Under the "mlflow" directory, the "hello-pt-mlflow" job shows how to use MLflow for tracking with both the MLflow sender -and receiver. The "hello-pt-tb-mlflow" job shows how to use the Tensorboard Sender, while the receiver is MLflow. -See :ref:`experiment_tracking_mlflow` for details. - -Weights & Biases -================ -Under the :github_nvflare_link:`wandb <examples/advanced/experiment-tracking/wandb>` directory, the -"hello-pt-wandb" job shows how to use Weights and Biases for experiment tracking with -the WandBWriter and WandBReceiver to log metrics. - -MONAI Integration -================= - -:github_nvflare_link:`Integration with MONAI <integration/monai>` uses the `NVFlareStatsHandler` -:class:`LogWriterForMetricsExchanger <nvflare.app_common.tracking.LogWriterForMetricsExchanger>` to connect to -:class:`MetricsRetriever <nvflare.app_common.metrics_exchange.MetricsRetriever>`. See the job -:github_nvflare_link:`spleen_ct_segmentation_local <integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local>` -for more details on this configuration. + experiment_tracking/experiment_tracking_apis + experiment_tracking/experiment_tracking_log_writer diff --git a/docs/programming_guide/experiment_tracking/experiment_tracking_apis.rst b/docs/programming_guide/experiment_tracking/experiment_tracking_apis.rst new file mode 100644 index 0000000000..000b9dbf62 --- /dev/null +++ b/docs/programming_guide/experiment_tracking/experiment_tracking_apis.rst @@ -0,0 +1,211 @@ +.. _experiment_tracking_apis: + +######################## +Experiment Tracking APIs +######################## + +.. figure:: ../../resources/experiment_tracking_diagram.png + :height: 500px + +To track training metrics such as accuracy or loss or AUC, we need to log these metrics with one of the experiment tracking systems. +Here we will discuss the following topics: + +- Logging metrics with MLflow, TensorBoard, or Weights & Biases +- Streaming metrics to the FL server +- Streaming to FL clients + +Logging metrics with MLflow, TensorBoard, or Weights & Biases +============================================================= + +Integrate MLflow logging to efficiently stream metrics to the MLflow server with just three lines of code: + +.. code-block:: python + + from nvflare.client.tracking import MLflowWriter + + mlflow = MLflowWriter() + + mlflow.log_metric("loss", running_loss / 2000, global_step) + +In this setup, we use ``MLflowWriter`` instead of using the MLflow API directly. +This abstraction is important, as it enables users to flexibly redirect your logging metrics to any destination, which we discuss in more detail later. + +The use of MLflow, TensorBoard, or Weights & Biases syntax will all work to stream the collected metrics to any supported experiment tracking system. +Choosing to use TBWriter, MLflowWriter, or WandBWriter is user preference based on your existing code and requirements. + +- ``MLflowWriter`` uses the Mlflow API operation syntax ``log_metric()`` +- ``TBWriter`` uses the TensorBoard SummaryWriter operation ``add_scalar()`` +- ``WandBWriter`` uses the Weights & Biases API operation ``log()`` + +Here are the APIs: + +.. code-block:: python + + class TBWriter(LogWriter): + def add_scalar(self, tag: str, scalar: float, global_step: Optional[int] = None, **kwargs): + def add_scalars(self, tag: str, scalars: dict, global_step: Optional[int] = None, **kwargs): + + + class WandBWriter(LogWriter): + def log(self, metrics: Dict[str, float], step: Optional[int] = None): + + + class MLflowWriter(LogWriter): + def log_param(self, key: str, value: any) -> None: + def log_params(self, values: dict) -> None: + def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_text(self, text: str, artifact_file_path: str) -> None: + def set_tag(self, key: str, tag: any) -> None: + def set_tags(self, tags: dict) -> None: + + +After you've modified the training code, you can use the NVFlare's job configuration to configure the system to stream the logs appropriately. + +Streaming metrics to FL server +============================== + +All metric key values are captured as events, with the flexibility to stream them to the most suitable destinations. +Let's add the ``ConvertToFedEvent`` to convert these metrics events to federated events so they will be sent to the server. + +Add this component to config_fed_client.json: + +.. code-block:: yaml + + { + "id": "event_to_fed", + "name": "ConvertToFedEvent", + "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."} + } + +If using the subprocess Client API with the ClientAPILauncherExecutor (rather than the in-process Client API with the InProcessClientAPIExecutor), +we need to add the ``MetricRelay`` to fire fed events, a ``CellPipe`` for metrics, and an ``ExternalConfiguator`` for client api initialization. + +.. code-block:: yaml + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + read_interval = 0.1 + } + }, + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + }, + { + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = ["metric_relay"] + } + } + + +On the server, configure the experiment tracking system in ``config_fed_server.conf`` using one of the following receivers. +Note that any of the receivers can be used regardless of the which writer is used. + +- ``MLflowReceiver`` for MLflow +- ``TBAnalyticsReceiver`` for TensorBoard +- ``WandBReceiver`` for Weights & Biases + +For example, here we add the ``MLflowReceiver`` component to the components configuration array: + +.. code-block:: yaml + + { + "id": "mlflow_receiver_with_tracking_uri", + "path": "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver", + "args": { + tracking_uri = "file:///{WORKSPACE}/{JOB_ID}/mlruns" + "kwargs": { + "experiment_name": "hello-pt-experiment", + "run_name": "hello-pt-with-mlflow", + "experiment_tags": { + "mlflow.note.content": "markdown for the experiment" + }, + "run_tags": { + "mlflow.note.content": "markdown describes details of experiment" + } + }, + "artifact_location": "artifacts" + } + } + +Notice the args{} are user defined, such as tracking_uri, experiment_name, tags etc., and will be specific to which receiver is configured. + +The MLflow tracking URL argument ``tracking_uri`` is None by default, which uses the MLflow default URL, ``http://localhost:5000``. +To make this accessible from another machine, make sure to change it to the correct URL, or point to to the ``mlruns`` directory in the workspace. + +:: + + tracking_uri = <the Mlflow Server endpoint URL> + +:: + + tracking_uri = "file:///{WORKSPACE}/{JOB_ID}/mlruns" + +You can change other arguments such as experiments, run_name, tags (using Markdown syntax), and artifact location. + +Start the MLflow server with one of the following commands: + +:: + + mlflow server --host 127.0.0.1 --port 5000 + +:: + + mlflow ui -port 5000 + +For more information with an example walkthrough, see the :github_nvflare_link:`FedAvg with SAG with MLflow tutorial <examples/hello-world/step-by-step/cifar10/sag_mlflow/sag_mlflow.ipynb>`. + + +Streaming metrics to FL clients +=============================== + +If streaming metrics to the FL server isn't preferred due to privacy or other concerns, users can alternatively stream metrics to the FL client. +In such cases, there's no need to add the ``ConvertToFedEvent`` component on the client side. +Additionally, since we're not streaming to the server side, there's no requirement to configure receivers in the server configuration. + +Instead to receive records on the client side, configure the metrics receiver in the client configuration instead of the server configuration. + +For example, in the TensorBoard configuration, add this component to ``config_fed_client.conf``: + +.. code-block:: yaml + + { + "id": "tb_analytics_receiver", + "name": "TBAnalyticsReceiver", + "args": {"events": ["analytix_log_stats"]} + } + +Note that the ``events`` argument is ``analytix_log_stats``, not ``fed.analytix_log_stats``, indicating that this is a local event. + +If using the ``MetricRelay`` component, we can similarly component event_type value from ``fed.analytix_log_stats`` to ``analytix_log_stats`` for convention. +We then must set the ``MetricRelay`` argument ``fed_event`` to ``false`` to fire local events rather than the default fed events. + +.. code-block:: yaml + + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "analytix_log_stats" + # how fast should it read from the peer + read_interval = 0.1 + fed_event = false + } + }, + +Then, the metrics will stream to the client. diff --git a/docs/programming_guide/experiment_tracking/experiment_tracking_log_writer.rst b/docs/programming_guide/experiment_tracking/experiment_tracking_log_writer.rst new file mode 100644 index 0000000000..58fc744f55 --- /dev/null +++ b/docs/programming_guide/experiment_tracking/experiment_tracking_log_writer.rst @@ -0,0 +1,152 @@ +.. _experiment_tracking_log_writer: + +############################## +Experiment Tracking Log Writer +############################## + +.. note:: + + This page covers experiment tracking using :class:`LogWriters <nvflare.app_common.tracking.log_writer.LogWriter>`, + which are configured and used with :ref:`executor` or :ref:`model_learner` on the FLARE-side code. + If using the Client API, please refer to :ref:`experiment_tracking_apis` and :ref:`client_api` for adding experiment tracking to your custom training code. + +*********************** +Overview and Approaches +*********************** + +In a federated computing setting, the data is distributed across multiple devices or systems, and training is run +on each device independently while preserving each client's data privacy. + +Assuming a federated system consisting of one server and many clients and the server coordinating the ML training of clients, +we can interact with ML experiment tracking tools in two different ways: + + - Client-side experiment tracking: Each client will directly send the log metrics/parameters to the ML experiment + tracking server (like MLflow or Weights and Biases) or local file system (like tensorboard) + - Aggregated experiment tracking: Clients will send the log metrics/parameters to FL server, and the FL server will + send the metrics to ML experiment tracking server or local file system + +This is enabled by Receivers, which can be configured on the FL server, FL client, or on both. Each approach will have its use cases and unique challenges. +Here we provide examples and describe the server-side approach: + + - Clients don't need to have access to the tracking server, avoiding the additional + authentication for every client. In many cases, the clients may be from different organizations + and different from the host organization of the experiment tracking server. + - Since we reduced connections to the tracking server from N FL clients to just one FL server, the traffic to the tracking server + can be highly reduced. In some cases, such as in MLFLow, the events can be buffered in the server and sent to the tracking + server in batches, further reducing the traffic to the tracking server. The buffer may add additional latency, so you can + disable the buffering if you can set the buffer flush time to 0 assuming the tracking server can take the traffic. + - Another key benefit of using server-side experiment tracking is that we separate the metrics data collection + from the metrics data delivery to the tracking server. Clients are only responsible for collecting metrics, and only the server needs to + know about the tracking server. This allows us to have different tools for data collection and data delivery. + For example, if the client has training code with logging in Tensorboard syntax, without changing the code, the server can + receive the logged data and deliver the metrics to MLflow. + - Server-side experiment tracking also can organize different clients' results into different experiment runs so they can be easily + compared side-by-side. + +************************************** +Tools, Sender, LogWriter and Receivers +************************************** + +With the "experiment_tracking" examples in the advanced examples directory, you can see how to track and visualize +experiments in real time and compare results by leveraging several experiment tracking solutions: + + - `Tensorboard <https://www.tensorflow.org/tensorboard>`_ + - `MLflow <https://mlflow.org/>`_ + - `Weights and Biases <https://wandb.ai/site>`_ + +.. note:: + + The user needs to sign up at Weights and Biases to access the service, NVFlare can not provide access. + +In the Federated Learning phase, users can choose an API syntax that they are used to from one +of above tools. NVFlare has developed components that mimic these APIs called +:class:`LogWriters <nvflare.app_common.tracking.log_writer.LogWriter>`. All clients experiment logs +are streamed over to the FL server (with :class:`ConvertToFedEvent<nvflare.app_common.widgets.convert_to_fed_event.ConvertToFedEvent>`), +where the actual experiment logs are recorded. The components that receive +these logs are called Receivers based on :class:`AnalyticsReceiver <nvflare.app_common.widgets.streaming.AnalyticsReceiver>`. +The receiver component leverages the experiment tracking tool and records the logs during the experiment run. + +In a normal setting, we would have pairs of sender and receivers, with some provided implementations in :mod:`nvflare.app_opt.tracking`: + + - TBWriter <-> TBAnalyticsReceiver + - MLflowWriter <-> MLflowReceiver + - WandBWriter <-> WandBReceiver + +You can also mix and match any combination of LogWriter and Receiver so you can write the ML code using one API +but use any experiment tracking tool or tools (you can use multiple receivers for the same log data sent from one sender). + +.. image:: ../../resources/experiment_tracking.jpg + +************************* +Experiment logs streaming +************************* + +On the client side, when a :class:`LogWriters <nvflare.app_common.tracking.log_writer.LogWriter>` writes the +metrics, instead of writing to files, it actually generates an NVFLARE event (of type `analytix_log_stats` by default). +The `ConvertToFedEvent` widget will turn the local event `analytix_log_stats` into a +fed event `fed.analytix_log_stats`, which will be delivered to the server side. + +On the server side, the :class:`AnalyticsReceiver <nvflare.app_common.widgets.streaming.AnalyticsReceiver>` is configured +to process `fed.analytix_log_stats` events, which writes received log data to the appropriate tracking solution. + +**************************************** +Support custom experiment tracking tools +**************************************** + +There are many different experiment tracking tools, and you might want to write a custom writer and/or receiver for your needs. + +There are three things to consider for developing a custom experiment tracking tool. + +Data Type +========= + +Currently, the supported data types are listed in :class:`AnalyticsDataType <nvflare.apis.analytix.AnalyticsDataType>`, and other data types can be added as needed. + +Writer +====== +Implement :class:`LogWriter <nvflare.app_common.tracking.log_writer.LogWriter>` interface with the API syntax. For each tool, we mimic the API syntax of the underlying tool, +so users can use what they are familiar with without learning a new API. +For example, for Tensorboard, TBWriter uses add_scalar() and add_scalars(); for MLflow, the syntax is +log_metric(), log_metrics(), log_parameter(), and log_parameters(); for W&B, the writer just has log(). +The data collected with these calls will all send to the AnalyticsSender to deliver to the FL server. + +Receiver +======== + +Implement :class:`AnalyticsReceiver <nvflare.app_common.widgets.streaming.AnalyticsReceiver>` interface and determine how to represent different sites' logs. In all three implementations +(Tensorboard, MLflow, WandB), each site's log is represented as one run. Depending on the individual tool, the implementation +can be different. For example, for both Tensorboard and MLflow, we create different runs for each client and map to the +site name. In the WandB implementation, we have to leverage multiprocess and let each run in a different process. + +***************** +Examples Overview +***************** + +The :github_nvflare_link:`experiment tracking examples <examples/advanced/experiment-tracking>` +illustrate how to leverage different writers and receivers. All examples are based upon the hello-pt example. + +TensorBoard +=========== +The example in the "tensorboard" directory shows how to use the Tensorboard Tracking Tool (for both the +sender and receiver). See :ref:`tensorboard_streaming` for details. + +MLflow +====== +Under the "mlflow" directory, the "hello-pt-mlflow" job shows how to use MLflow for tracking with both the MLflow sender +and receiver. The "hello-pt-tb-mlflow" job shows how to use the Tensorboard Sender, while the receiver is MLflow. +See :ref:`experiment_tracking_mlflow` for details. + +Weights & Biases +================ +Under the :github_nvflare_link:`wandb <examples/advanced/experiment-tracking/wandb>` directory, the +"hello-pt-wandb" job shows how to use Weights and Biases for experiment tracking with +the WandBWriter and WandBReceiver to log metrics. + +MONAI Integration +================= + +:github_nvflare_link:`Integration with MONAI <integration/monai>` uses the `NVFlareStatsHandler` +:class:`LogWriterForMetricsExchanger <nvflare.app_common.tracking.LogWriterForMetricsExchanger>` to connect to +:class:`MetricsRetriever <nvflare.app_common.metrics_exchange.MetricsRetriever>`. See the job +:github_nvflare_link:`spleen_ct_segmentation_local <integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local>` +for more details on this configuration. diff --git a/docs/resources/experiment_tracking_diagram.png b/docs/resources/experiment_tracking_diagram.png new file mode 100644 index 0000000000..58f5d2fc5e Binary files /dev/null and b/docs/resources/experiment_tracking_diagram.png differ diff --git a/docs/resources/processor_interface_design.png b/docs/resources/processor_interface_design.png new file mode 100644 index 0000000000..ff84dae5d2 Binary files /dev/null and b/docs/resources/processor_interface_design.png differ diff --git a/docs/resources/secure_horizontal_xgb.png b/docs/resources/secure_horizontal_xgb.png new file mode 100644 index 0000000000..c09c847e2c Binary files /dev/null and b/docs/resources/secure_horizontal_xgb.png differ diff --git a/docs/resources/secure_vertical_xgb.png b/docs/resources/secure_vertical_xgb.png new file mode 100644 index 0000000000..f8e596d6dc Binary files /dev/null and b/docs/resources/secure_vertical_xgb.png differ diff --git a/docs/resources/xgb_communicator.jpg b/docs/resources/xgb_communicator.jpg new file mode 100644 index 0000000000..c04c547674 Binary files /dev/null and b/docs/resources/xgb_communicator.jpg differ diff --git a/docs/user_guide/federated_xgboost.rst b/docs/user_guide/federated_xgboost.rst new file mode 100644 index 0000000000..4f38f0587b --- /dev/null +++ b/docs/user_guide/federated_xgboost.rst @@ -0,0 +1,28 @@ +############################## +Federated XGBoost with NVFlare +############################## + +XGBoost (https://github.com/dmlc/xgboost) is an open-source project that +implements machine learning algorithms under the Gradient Boosting framework. +It is an optimized distributed gradient boosting library designed to be highly +efficient, flexible and portable. +This implementation uses MPI (message passing interface) for client +communication and synchronization. + +MPI requires the underlying communication network to be perfect - a single +message drop causes the training to fail. + +This is usually achieved via a highly reliable special-purpose network like NCCL. + +The open-source XGBoost supports federated paradigm, where clients are in different +locations and communicate with each other with gRPC over internet connections. + +We introduce federated XGBoost with NVFlare for a more reliable federated setup. + +.. toctree:: + :maxdepth: 1 + + federated_xgboost/reliable_xgboost_design + federated_xgboost/reliable_xgboost_timeout + federated_xgboost/secure_xgboost_design + federated_xgboost/secure_xgboost_user_guide diff --git a/docs/user_guide/federated_xgboost/reliable_xgboost_design.rst b/docs/user_guide/federated_xgboost/reliable_xgboost_design.rst new file mode 100644 index 0000000000..99747dcd11 --- /dev/null +++ b/docs/user_guide/federated_xgboost/reliable_xgboost_design.rst @@ -0,0 +1,65 @@ +################################# +Reliable Federated XGBoost Design +################################# + + +************************* +Flare as XGBoost Launcher +************************* + +NVFLARE serves as a launchpad to start the XGBoost system. +Once started, the XGBoost system runs independently of FLARE, +as illustrated in the following figure. + +.. figure:: ../../resources/loose_xgb.png + :height: 500px + +There are a few potential problems with this approach: + + - As we know, MPI requires a perfect communication network, + whereas the simple gRPC over the internet could be unstable. + + - For each job, the XGBoost Server must open a port for clients to connect to. + This adds burden to request IT for the additional port in the real-world situation. + Even if a fixed port is allowed to open, and we reuse that port, + multiple XGBoost jobs can not be run at the same time, + since each XGBoost job requires a different port number. + + +***************************** +Flare as XGBoost Communicator +***************************** + +FLARE provides a highly flexible, scalable and reliable communication mechanism. +We enhance the reliability of federated XGBoost by using FLARE as the communicator of XGBoost, +as shown here: + +.. figure:: ../../resources/tight_xgb.png + :height: 500px + +Detailed Design +=============== + +The open-source Federated XGBoost (c++) uses gRPC as the communication protocol. +To use FLARE as the communicator, we simply route XGBoost's gRPC messages through FLARE. +To do so, we change the server endpoint of each XGBoost client to a local gRPC server +(LGS) within the FLARE client. + +.. figure:: ../../resources/fed_xgb_detail.png + :height: 500px + +As shown in this diagram, there is a local GRPC server (LGS) for each site +that serves as the server endpoint for the XGBoost client on the site. +Similarly, there is a local GRPC Client (LGC) on the FL Server that +interacts with the XGBoost Server. The message path between the XGBoost Client and +the XGBoost Server is as follows: + + 1. The XGBoost client generates a gRPC message and sends it to the LGS in FLARE Client + 2. FLARE Client forwards the message to the FLARE Server. This is a reliable FLARE message. + 3. FLARE Server uses the LGC to send the message to the XGBoost Server. + 4. XGBoost Server sends the response back to the LGC in FLARE Server. + 5. FLARE Server sends the response back to the FLARE Client. + 6. FLARE Client sends the response back to the XGBoost Client via the LGS. + +Please note that the XGBoost Client (c++) component could be running as a separate process +or within the same process of FLARE Client. diff --git a/docs/user_guide/federated_xgboost/reliable_xgboost_timeout.rst b/docs/user_guide/federated_xgboost/reliable_xgboost_timeout.rst new file mode 100644 index 0000000000..62b8c2516a --- /dev/null +++ b/docs/user_guide/federated_xgboost/reliable_xgboost_timeout.rst @@ -0,0 +1,93 @@ +############################################ +Reliable Federated XGBoost Timeout Mechanism +############################################ + +NVFlare introduces a tightly-coupled integration between XGBoost and NVFlare. +NVFlare implements the ReliableMessage mechanism to make XGBoost’s server/client +interactions more robust over unstable internet connections. + +Unstable internet connection is the situation where the connections between +the communication endpoints have random disconnects/reconnects and unstable speed. +It is not meant to be an extended internet outage. + +ReliableMessage does not mean guaranteed delivery. +It only means that it will try its best to deliver the message to the peer. +If one attempt fails, it will keep trying until either the message is +successfully delivered or a specified "transaction timeout" is reached. + +***************** +Timeout Mechanism +***************** + +In runtime, the FLARE System is configured with a few important timeout parameters. + +ReliableMessage Timeout +======================= + +There are two timeout values to control the behavior of ReliableMessage (RM). + +Per-message Timeout +------------------- + +Essentially RM tries to resend the message until delivered successfully. +Each resend of the message requires a timeout value. +This value should be defined based on the message size, overall network speed, +and the amount of time needed to process the message in a normal situation. +For example, if an XGBoost message takes no more than 5 seconds to be +sent, processed, and replied. +The per-message timeout should be set to 5 seconds. + +.. note:: + + Note that the initial XGBoost message might take more than 100 seconds + depends on the dataset size. + +Transaction Timeout +------------------- + +This value defines how long you want RM to keep retrying until done, in case +of unstable connection. +This value should be defined based on the overall stability of the connection, +nature of the connection, and how quickly the connection is restored. +For occasional connection glitches, this value shouldn't have to be too big +(e.g. 20 seconds). +However if the outage is long (say 60 seconds or longer), then this value +should be big enough. + +.. note:: + + Note that even if you think the connection is restored (e.g. replugged + the internet cable or reactivated WIFI), the underlying connection + layer may take much longer to actually restore connections (e.g. up to + a few minutes)! + +.. note:: + + Note: if the transaction timeout is <= per-message timeout, then the + message will be sent through simple messaging - no retry will be done + in case of failure. + +XGBoost Client Operation Timeout +================================ + +To prevent a XGBoost client from running forever, the XGBoost/FLARE +integration lets you define a parameter (max_client_op_interval) on the +server side to control the max amount of time permitted for a client to be +silent (i.e. no messages sent to the server). +The default value of this parameter is 900 seconds, meaning that if no XGB +message is received from the client for over 900 seconds, then that client +is considered dead, and the whole job is aborted. + +*************************** +Configure Timeouts Properly +*************************** + +These timeout values are related. For example, if the transaction timeout +is greater than the server timeout, then it won't be that effective since +the server will treat the client to be dead once the server timeout is reached +anyway. Similarly, it does not make sense to have transaction timeout > XGBoost +client op timeout. + +In general, follow this rule: + +Per-message Timeout < Transaction Timeout < XGBoost Client Operation Timeout diff --git a/docs/user_guide/federated_xgboost/secure_xgboost_design.rst b/docs/user_guide/federated_xgboost/secure_xgboost_design.rst new file mode 100644 index 0000000000..6f69255f59 --- /dev/null +++ b/docs/user_guide/federated_xgboost/secure_xgboost_design.rst @@ -0,0 +1,85 @@ +############################### +Secure Federated XGBoost Design +############################### + +Collaboration Modes and Secure Patterns +======================================= + +Horizontal Secure +----------------- + +For horizontal XGBoost, each party holds "equal status" - whole feature and label for partial population, while the federated server performs aggregation, without owning any data. +Hence in this case, the federated server is the "minor contributor" from model training perspective, and clients have a concern of leaking any information to the server. +Under this setting, the protection is mainly against the federated server over local histograms. + +To protect the local histograms for horizontal collaboration, the local histograms will be encrypted before sending to the federated server for aggregation. +The aggregation will then be performed over ciphertexts and the encrypted global histograms will be returned to clients, where they will be decrypted and used for tree building. + +Vertical Secure +--------------- + +For vertical XGBoost, the active party holds the label, which cannot be accessed by passive parties and can be considered the most valuable asset for the whole process. +Therefore, the active party in this case is the "major contributor" from model training perspective, and it will have a concern of leaking this information to passive clients. +In this case, the security protection is mainly against passive clients over the label information. + +To protect label information for vertical collaboration, at every round of XGBoost after the active party computes the gradients for each sample at the active party, the gradients will be encrypted before sending to passive parties. +Upon receiving the encrypted gradients (ciphertext), they will be accumulated according to the specific feature distribution at each passive party. +The resulting cumulative histograms will be returned to the active party, decrypted, and further be used for tree building at the active party. + +Decoupled Encryption with Processor Interface +============================================= + +In our current design, XGBoost communication is routed through the NVIDIA FLARE Communicator layer via local gRPC handlers. +From communication's perspective, the previous direct messages within XGBoost are now handled by FL communicator - they become "external communications" to and from XGBoost via FL system. +This gives us flexibilities in performing message operations both within XGBoost (before entering FL communicator) and within FL system (by FL communicator) + +.. figure:: ../../resources/xgb_communicator.jpg + :height: 500px + +With NVFlare, the XGBoost plugin will be implemented in C++, while the FL system communicator will be implemented in Python. A processor interface is designed and developed to properly connect the two by taking plugins implemented towards a specific HE method and collaboration mode: + +.. figure:: ../../resources/processor_interface_design.png + :height: 500px + +Processor Interface Design + + 1. Upon receiving specific MPI calls from XGBoost, each corresponding party calls interface for data processing (serialization, etc.), providing necessary information: g/h pairs, or local G/H histograms + 2. Processor interface performs necessary processing (and encryption), and send the results back as a processed buffer + 3. Each party then forward the message to local gRPC handler on FL system side + 4. After FL communication involving message routing and computation, each party receives the result buffer upon MPI calls. + 5. Each FL party then sends the received buffer to processor interface for interpretation + 6. Interface performs necessary processing (deserialization, etc.), recovers proper information, and sends the result back to XGBoost for further computation + + +Note that encryption/decryption can be performed either by processor interface (C++), or at local gRPC handler (Python) depending on the particular HE library and scheme being considered. + +System Design +============= +With the secure solutions, communication patterns, and processor interface, below we provide example designs for secure federated XGBoost - both vertical and horizontal. + +For vertical pipeline: + + 1. active party first compute g/h with the label information it owns + 2. g/h data will be sent to processor interface, encrypted with C++ based encryption util library, and sent to passive party via FL communication + 3. passive party provides indexing information for histogram computation according to local feature distributions, and the processor interface will perform aggregation with E(g/h) received. + 4. The resulting E(G/H) will be sent to active party via FL message routing + 5. Decrypted by processor interface on active party side, tree building can be performed with global histogram information + +.. figure:: ../../resources/secure_vertical_xgb.png + :height: 500px + +Secure Vertical Federated XGBoost with XGBoost-side Encryption +In this case, the "heavy-lifting" jobs - encryption, secure aggregation, etc. - are done by processor interface. + +For horizontal pipeline: + + 1. All parties sends their local G/H histograms to FL side via processor interface, in this design processor interface only performs buffer preparation without any complicated processing steps + 2. Before sending to federated server, the G/H histograms will be encrypted at local gRPC handler with Python-based encryption util library + 3. Federated server will perform secure aggregation over received partial E(G/H), and distribute the global E(G/H) to each clients, where the global histograms will be decrypted, and used for further tree-building + +.. figure:: ../../resources/secure_horizontal_xgb.png + :height: 500px + +Secure Horizontal Federated XGBoost with FL-side Encryption +In this case, the encryption is done on the FL system side. + diff --git a/docs/user_guide/federated_xgboost/secure_xgboost_user_guide.rst b/docs/user_guide/federated_xgboost/secure_xgboost_user_guide.rst new file mode 100644 index 0000000000..fa5f4df7f4 --- /dev/null +++ b/docs/user_guide/federated_xgboost/secure_xgboost_user_guide.rst @@ -0,0 +1,382 @@ +########################## +NVFlare XGBoost User Guide +########################## + +Overview +======== +NVFlare supports federated training with XGBoost. It provides the following advantages over doing the training natively with XGBoost: + +- Secure training with Homomorphic Encryption (HE) +- Lifecycle management of XGBoost processes. +- Reliable messaging which can overcome network glitches +- Training over complicated networks with relays. + +It supports federated training in the following 4 modes: + +1. **horizontal(h)**: Row split without encryption +2. **vertical(v)**: Column split without encryption +3. **horizontal_secure(hs)**: Row split with HE (Requires at least 3 clients. With 2 clients, the other client's histogram can be deduced.) +4. **vertical_secure(vs)**: Column split with HE + +When running with NVFlare, all the GRPC connections in XGBoost are local and the messages are forwarded to other clients through NVFlare's CellNet communication. The local GRPC ports are selected automatically by NVFlare. + +The encryption, decryption, and secure aggregation in XGBoost are handled by processor plugins, which are external components that can be installed at runtime. The plugins are bundled with NVFlare. + +Prerequisites +============= +Required Python Packages +------------------------ + +NVFlare 2.4.2 or above, + +.. code-block:: bash + + pip install nvflare~=2.4.2 + +A version of XGBoost that supports secure mode is required, which can be installed using this command, + +.. code-block:: bash + + pip install https://s3-us-west-2.amazonaws.com/xgboost-nightly-builds/vertical-federated-learning/xgboost-2.1.0.dev0%2Bde4013fc733648dfe5c2c803a13e2782056e00a2-py3-none-manylinux_2_28_x86_64.whl + +``TenSEAL`` package is needed for horizontal secure training, + +.. code-block:: bash + + pip install tenseal + +``ipcl_python`` package is required for vertical secure training if **nvflare** plugin is used. This package is not needed if **cuda_paillier** plugin is used. + +.. code-block:: bash + + pip install ipcl-python + +This package is only available for Python 3.8 on PyPI. For other versions of python, it needs to be installed from github, + +.. code-block:: bash + + pip install git+https://github.com/intel/pailliercryptolib_python.git@development + +System Environments +------------------- +To support secure training, several homomorphic encryption libraries are used. Those libraries require Intel CPU or Nvidia GPU. + +Linux is the preferred OS. It's tested extensively under Ubuntu 22.4. + +The following docker image is recommended for GPU training: + +:: + + nvcr.io/nvidia/pytorch:24.03-py3 + +Most Linux distributions are supported, as long as they have a recent glibc. The oldest glibc version tested is 2.35. Systems with older glibc may run into issues. + +NVFlare Provisioning +-------------------- +For horizontal secure training, the NVFlare system must be provisioned with homomorphic encryption context. The HEBuilder in ``project.yml`` is used to achieve this. +An example configuration can be found at :github_nvflare_link:`secure_project.yml <examples/advanced/cifar10/cifar10-real-world/workspaces/secure_project.yml#L64>`. + +This is a snippet of the ``secure_project.yml`` file with the HEBuilder: + +.. code-block:: yaml + + api_version: 3 + name: secure_project + description: NVIDIA FLARE sample project yaml file for CIFAR-10 example + + participants: + + ... + + builders: + - path: nvflare.lighter.impl.workspace.WorkspaceBuilder + args: + template_file: master_template.yml + - path: nvflare.lighter.impl.template.TemplateBuilder + - path: nvflare.lighter.impl.static_file.StaticFileBuilder + args: + config_folder: config + overseer_agent: + path: nvflare.ha.dummy_overseer_agent.DummyOverseerAgent + overseer_exists: false + args: + sp_end_point: localhost:8102:8103 + heartbeat_interval: 6 + - path: nvflare.lighter.impl.he.HEBuilder + args: + poly_modulus_degree: 8192 + coeff_mod_bit_sizes: [60, 40, 40] + scale_bits: 40 + scheme: CKKS + - path: nvflare.lighter.impl.cert.CertBuilder + - path: nvflare.lighter.impl.signature.SignatureBuilder + + +Data Preparation +================ +Data must be properly formatted for federated XGBoost training based on split mode (horizontal or vertical). + +For horizontal training, the datasets on all clients must share the same columns. + +For vertical training, the datasets on all clients contain different columns, but must share overlapping rows. For more details on vertical split preprocessing, refer to the :github_nvflare_link:`Vertical XGBoost Example <examples/advanced/vertical_xgboost>`. + +XGBoost Plugin Configuration +============================ +XGBoost requires a plugin to handle secure training. +Two plugins are initially shipped with NVFlare, + +- **cuda_paillier**: The default plugin. This plugin uses GPU for cryptographic operations. +- **nvflare**: This plugin forwards data locally to NVFlare process for encryption. + +Vertical (Non-secure) +--------------------- +Any plugin can be used for vertical training. No configuration change is needed. + +Horizontal (Non-secure) +----------------------- +Any plugin can be used for horizontal training. No configuration change is needed. + +Vertical Secure +--------------- +Both plugins can be used for vertical secure training. + +The default cuda_paillier plugin is preferred because it uses GPU for faster cryptographic operations. + +.. note:: + + **cuda_paillier** plugin requires NVIDIA GPUs that support compute capability 7.0 or higher. Please refer to https://developer.nvidia.com/cuda-gpus for more information. + +If you see the following errors in the log, it means either no GPU is available or the GPU does not meet the requirements: + +:: + + CUDA runtime API error no kernel image is available for execution on the device at line 241 in file /my_home/nvflare-internal/processor/src/cuda-plugin/paillier.h + 2024-07-01 12:19:15,683 - SimulatorClientRunner - ERROR - run_client_thread error: EOFError: + + +In this case, the nvflare plugin can be used to perform encryption on CPUs, which requires the ipcl-python package. +The plugin can be configured in the ``local/resources.json`` file on clients: + +.. code-block:: json + + { + "xgb_plugin_name": "nvflare" + } + +or by setting this environment variable, + +.. code-block:: bash + + export NVFLARE_XGB_PLUGIN_NAME=nvflare + +Horizontal Secure +----------------- +The plugin setup is the same as vertical secure. + +This mode requires the tenseal package for all plugins. +The provisioning of NVFlare systems must include tenseal context. +See :ref:`provisioning` for details. + + +Job Configuration +================= +Controller +---------- + +On the server side, following controller must be configured in workflows, + +``nvflare.app_opt.xgboost.histogram_based_v2.fed_controller.XGBFedController`` + +Even though the XGBoost training is performed on clients, the parameters are configured on the server so all clients share the same configuration. +XGBoost parameters are defined here, https://xgboost.readthedocs.io/en/stable/python/python_intro.html#setting-parameters + +- **num_rounds**: Number of training rounds +- **training_mode**: Training mode, must be one of the following: horizontal, vertical, horizontal_secure, vertical_secure. +- **xgb_params**: The training parameters defined in this dict are passed to XGBoost as params. +- **xgb_options**: This dict contains other optional parameters passed to XGBoost. Currently, only early_stopping_rounds is supported. +- **client_ranks**: A dict that maps client name to rank. + +Executor +-------- + +On the client side, following executor must be configured in executors, + +``nvflare.app_opt.xgboost.histogram_based_v2.fed_executor.FedXGBHistogramExecutor`` + +Only one parameter is required for executor, + +- **data_loader_id**: The component ID of Data Loader + +Data Loader +----------- + +On the client side, a data loader must be configured in the components. The SecureDataLoader can be used if the data is pre-processed. For example, + +.. code-block:: json + + { + "id": "dataloader", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader.SecureDataLoader", + "args": { + "rank": 0, + "folder": "/opt/dataset/vertical_xgb_data" + } + } + + +If the data requires any special processing, a custom loader can be implemented. The loader must implement the XGBDataLoader interface. + + +Job Example +=========== + +Vertical Training +----------------- + +Here are the configuration files for a vertical secure training job. If encryption is not needed, just change the training_mode to vertical. + +config_fed_server.json + +.. code-block:: json + + { + "format_version": 2, + "num_rounds": 3, + "workflows": [ + { + "id": "xgb_controller", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_controller.XGBFedController", + "args": { + "num_rounds": "{num_rounds}", + "training_mode": "vertical_secure", + "xgb_options": { + "early_stopping_rounds": 2 + }, + "xgb_params": { + "max_depth": 3, + "eta": 0.1, + "objective": "binary:logistic", + "eval_metric": "auc", + "tree_method": "hist", + "nthread": 1 + }, + "client_ranks": { + "site-1": 0, + "site-2": 1 + } + } + } + ] + } + + +config_fed_client.json + +.. code-block:: json + + { + "format_version": 2, + "executors": [ + { + "tasks": [ + "config", + "start" + ], + "executor": { + "id": "Executor", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_executor.FedXGBHistogramExecutor", + "args": { + "data_loader_id": "dataloader" + } + } + } + ], + "components": [ + { + "id": "dataloader", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader.SecureDataLoader", + "args": { + "rank": 0, + "folder": "/opt/dataset/vertical_xgb_data" + } + } + ] + } + + +Horizontal Training +------------------- + +The configuration for horizontal training is the same as vertical except training_mode and the data loader must point to horizontal split data. + +config_fed_server.json + +.. code-block:: json + + { + "format_version": 2, + "num_rounds": 3, + "workflows": [ + { + "id": "xgb_controller", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_controller.XGBFedController", + "args": { + "num_rounds": "{num_rounds}", + "training_mode": "horizontal_secure", + "xgb_options": { + "early_stopping_rounds": 2 + }, + "xgb_params": { + "max_depth": 3, + "eta": 0.1, + "objective": "binary:logistic", + "eval_metric": "auc", + "tree_method": "hist", + "nthread": 1 + }, + "client_ranks": { + "site-1": 0, + "site-2": 1 + }, + "in_process": true + } + } + ] + } + + + + +config_fed_client.json + +.. code-block:: json + + { + "format_version": 2, + "executors": [ + { + "tasks": [ + "config", + "start" + ], + "executor": { + "id": "Executor", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_executor.FedXGBHistogramExecutor", + "args": { + "data_loader_id": "dataloader", + "in_process": true + } + } + } + ], + "components": [ + { + "id": "dataloader", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader.SecureDataLoader", + "args": { + "rank": 0, + "folder": "/data/xgboost_secure/dataset/horizontal_xgb_data" + } + } + ] + } diff --git a/examples/getting_started/README.md b/examples/getting_started/README.md new file mode 100644 index 0000000000..bb49a50c88 --- /dev/null +++ b/examples/getting_started/README.md @@ -0,0 +1,29 @@ +# Getting Started with NVFlare +NVFlare is an open-source framework that allows researchers and data scientists to seamlessly move their +machine learning and deep learning workflows into a federated paradigm. + +### Basic Concepts +At the heart of NVFlare lies the concept of collaboration through "tasks." An FL controller assigns tasks +(e.g., training on local data) to one or more FL clients, processes returned results (e.g., model weight updates), +and may assign additional tasks based on these results and other factors (e.g., a pre-configured number of training rounds). +The clients run executors which can listen for tasks and perform the necessary computations locally, such as model training. +This task-based interaction repeats until the experiment’s objectives are met. + +We can also add data filters (for example, for [homomorphic encryption](https://www.usenix.org/conference/atc20/presentation/zhang-chengliang) +or [differential privacy filters](https://arxiv.org/abs/1910.00962)) to the task data +or results received or produced by the server or clients. + + + +### Examples +We provide several examples to quickly get you started using NVFlare's Job API. +Each example folder includes basic job configurations for running different FL algorithms. +Starting from [FedAvg](https://arxiv.org/abs/1602.05629), to more advanced ones, +such as [FedOpt](https://arxiv.org/abs/2003.00295), or [SCAFFOLD](https://arxiv.org/abs/1910.06378). + +### 1. [PyTorch Examples](./pt/README.md) +### 2. [Tensorflow Examples](./tf/README.md) +### 3. [Scikit-Learn Examples](./sklearn/README.md) + +> [!NOTE] +> More examples can be found at https://nvidia.github.io/NVFlare. diff --git a/examples/getting_started/pt/README.md b/examples/getting_started/pt/README.md new file mode 100644 index 0000000000..4509bb7635 --- /dev/null +++ b/examples/getting_started/pt/README.md @@ -0,0 +1,68 @@ +# Getting Started with NVFlare (PyTorch) +[](https://pytorch.org) + +We provide several examples to quickly get you started using NVFlare's Job API. +All examples in this folder are based on using [PyTorch](https://pytorch.org/) as the model training framework. +Furthermore, we support [PyTorch Lightning](https://lightning.ai). + +## Setup environment +First, install nvflare and dependencies: +```commandline +pip install -r requirements.txt +``` + +## Tutorials +A good starting point for understanding the Job API scripts and NVFlare components are the following tutorials. +### 1. [Federated averaging using script executor](./nvflare_pt_getting_started.ipynb) +Tutorial on [FedAvg](https://arxiv.org/abs/1602.05629) using the [Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html). + +### 2. [Federated averaging using script executor with Lightning API](./nvflare_lightning_getting_started.ipynb) +Tutorial on [FedAvg](https://arxiv.org/abs/1602.05629) using the [Lightning Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html#id4) + +## Examples +You can also run any of the below scripts directly using +```commandline +python "script_name.py" +``` +### 1. [Federated averaging using script executor](./fedavg_script_executor_cifar10.py) +Implementation of [FedAvg](https://arxiv.org/abs/1602.05629) using the [Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html). +```commandline +python fedavg_script_executor_cifar10.py +``` +### 2. [Federated averaging using script executor with Lightning API](./fedavg_script_executor_lightning_cifar10.py) +Implementation of [FedAvg](https://arxiv.org/abs/1602.05629) using the [Lightning Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html#id4) +```commandline +python fedavg_script_executor_lightning_cifar10.py +``` +### 3. [Federated averaging using the script executor for all clients](./fedavg_script_executor_cifar10_all.py) +Implementation of [FedAvg](https://arxiv.org/abs/1602.05629) using the [Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html). +Here, we deploy the same configuration to all clients. +```commandline +python fedavg_script_executor_cifar10_all.py +``` +### 4. [Federated averaging using script executor and differential privacy filter](./fedavg_script_executor_dp_filter_cifar10.py) +Implementation of [FedAvg](https://arxiv.org/abs/1602.05629) using the [Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html) +with additional [differential privacy filters](https://arxiv.org/abs/1910.00962) on the client side. +```commandline +python fedavg_script_executor_dp_filter_cifar10.py +``` +### 5. [Swarm learning using script executor](./swarm_script_executor_cifar10.py) +Implementation of [swarm learning](https://www.nature.com/articles/s41586-021-03583-3) using the [Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html) +```commandline +python swarm_script_executor_cifar10.py +``` +### 6. [Cyclic weight transfer using script executor](./cyclic_cc_script_executor_cifar10.py) +Implementation of [cyclic weight transfer](https://arxiv.org/abs/1709.05929) using the [Client API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/client_api.html) +```commandline +python cyclic_cc_script_executor_cifar10.py +``` +### 7. [Federated averaging using model learning](./fedavg_model_learner_xsite_val_cifar10.py)) +Implementation of [FedAvg](https://arxiv.org/abs/1602.05629) using the [model learner class](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type/model_learner.html), +followed by [cross site validation](https://nvflare.readthedocs.io/en/main/programming_guide/controllers/cross_site_model_evaluation.html) +for federated model evaluation. +```commandline +python fedavg_model_learner_xsite_val_cifar10.py +``` + +> [!NOTE] +> More examples can be found at https://nvidia.github.io/NVFlare. diff --git a/examples/getting_started/pt/src/train_eval_submit.py b/examples/getting_started/pt/src/train_eval_submit.py index 1f4cad1b88..0273652746 100644 --- a/examples/getting_started/pt/src/train_eval_submit.py +++ b/examples/getting_started/pt/src/train_eval_submit.py @@ -29,7 +29,7 @@ CIFAR10_ROOT = "/tmp/nvflare/data/cifar10" # (optional) We change to use GPU to speed things up. # if you want to use CPU, change DEVICE="cpu" -DEVICE = "cuda:0" +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" def define_parser(): diff --git a/examples/getting_started/pt/swarm_script_executor_cifar10.py b/examples/getting_started/pt/swarm_script_executor_cifar10.py index 4a04f10f40..db0092d83e 100644 --- a/examples/getting_started/pt/swarm_script_executor_cifar10.py +++ b/examples/getting_started/pt/swarm_script_executor_cifar10.py @@ -47,19 +47,23 @@ executor = ScriptExecutor(task_script_path=train_script) job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"]) - client_controller = SwarmClientController() - job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) - - client_controller = CrossSiteEvalClientController() - job.to(client_controller, f"site-{i}", tasks=["cse_*"]) - # In swarm learning, each client acts also as an aggregator aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) - job.to(aggregator, f"site-{i}") # In swarm learning, each client uses a model persistor and shareable_generator - job.to(PTFileModelPersistor(model=Net()), f"site-{i}") - job.to(SimpleModelShareableGenerator(), f"site-{i}") + persistor = PTFileModelPersistor(model=Net()) + shareable_generator = SimpleModelShareableGenerator() + + persistor_id = job.as_id(persistor) + client_controller = SwarmClientController( + aggregator_id=job.as_id(aggregator), + persistor_id=persistor_id, + shareable_generator_id=job.as_id(shareable_generator), + ) + job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) + + client_controller = CrossSiteEvalClientController(persistor_id=persistor_id) + job.to(client_controller, f"site-{i}", tasks=["cse_*"]) # job.export_job("/tmp/nvflare/jobs/job_config") job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/getting_started/sklearn/README.md b/examples/getting_started/sklearn/README.md new file mode 100644 index 0000000000..9c73e9fe34 --- /dev/null +++ b/examples/getting_started/sklearn/README.md @@ -0,0 +1,25 @@ +# Getting Started with NVFlare (scikit-learn) +[](https://scikit-learn.org/) + +We provide examples to quickly get you started using NVFlare's Job API. +All examples in this folder are based on using [scikit-learn](https://scikit-learn.org/), a popular library for general machine learning with Python. + +## Setup environment +First, install nvflare and dependencies: +```commandline +pip install -r requirements.txt +``` + +## Examples +You can also run any of the below scripts directly using +```commandline +python "script_name.py" +``` +### 1. [Federated K-Means Clustering](./kmeans_script_executor_higgs.py) +Implementation of [K-Means](https://arxiv.org/abs/1602.05629). For more details see this [example](../../advanced/sklearn-kmeans/README.md). +```commandline +python kmeans_script_executor_higgs.py +``` + +> [!NOTE] +> More examples can be found at https://nvidia.github.io/NVFlare. diff --git a/examples/getting_started/tf/README.md b/examples/getting_started/tf/README.md new file mode 100644 index 0000000000..ab5a57127e --- /dev/null +++ b/examples/getting_started/tf/README.md @@ -0,0 +1,189 @@ +# Getting Started with NVFlare (TensorFlow) +[](https://tensorflow.org/) + +We provide several examples to quickly get you started using NVFlare's Job API. +All examples in this folder are based on using [TensorFlow](https://tensorflow.org/) as the model training framework. + +## Simulated Federated Learning with CIFAR10 Using Tensorflow + +This example shows `Tensorflow`-based classic Federated Learning +algorithms, namely FedAvg and FedOpt on CIFAR10 +dataset. This example is analogous to [the example using `Pytorch` +backend](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-sim) +on the same dataset, where same experiments +were conducted and analyzed. You should expect the same +experimental results when comparing this example with the `Pytorch` one. + +In this example, the latest Client APIs were used to implement +client-side training logics (details in file +[`cifar10_tf_fl_alpha_split.py`](src/cifar10_tf_fl_alpha_split.py)), +and the new +[`FedJob`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/job_config/fed_job.py#L106) +APIs were used to programmatically set up an +`nvflare` job to be exported or ran by simulator (details in file +[`tf_fl_script_executor_cifar10.py`](tf_fl_script_executor_cifar10.py)), +alleviating the need of writing job config files, simplifying +development process. + +Before continuing with the following sections, you can first refer to +the [getting started notebook](nvflare_tf_getting_started.ipynb) +included under this folder, to learn more about the implementation +details, with an example walkthrough of FedAvg using a small +Tensorflow model. + +## 1. Install requirements + +Install required packages +``` +pip install --upgrade pip +pip install -r ./requirements.txt +``` + +> **_NOTE:_** We recommend either using a containerized deployment or virtual environment, +> please refer to [getting started](https://nvflare.readthedocs.io/en/latest/getting_started.html). + + +## 2. Run experiments + +This example uses simulator to run all experiments. The script +[`tf_fl_script_executor_cifar10.py`](tf_fl_script_executor_cifar10.py) +is the main script to be used to launch different experiments with +different arguments (see sections below for details). A script +[`run_jobs.sh`](run_jobs.sh) is also provided to run all experiments +described below at once: +``` +bash ./run_jobs.sh +``` +The CIFAR10 dataset will be downloaded when running any experiment for +the first time. `Tensorboard` summary logs will be generated during +any experiment, and you can use `Tensorboard` to visualize the +training and validation process as the experiment runs. Data split +files, summary logs and results will be saved in a workspace +directory, which defaults to `/tmp` and can be configured by setting +`--workspace` argument of the `tf_fl_script_executor_cifar10.py` +script. + +> [!WARNING] +> If you are using GPU, make sure to set the following +> environment variables before running a training job, to prevent +> `Tensoflow` from allocating full GPU memory all at once: +> `export TF_FORCE_GPU_ALLOW_GROWTH=true && export +> TF_GPU_ALLOCATOR=cuda_malloc_asyncp` + +The set-up of all experiments in this example are kept the same as +[the example using `Pytorch` +backend](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-sim). Refer +to the `Pytorch` example for more details. Similar to the Pytorch +example, we here also use Dirichelet sampling on CIFAR10 data labels +to simulate data heterogeneity among data splits for different client +sites, controlled by an alpha value, ranging from 0 (not including 0) +to 1. A high alpha value indicates less data heterogeneity, i.e., an +alpha value equal to 1.0 would result in homogeneous data distribution +among different splits. + +### 2.1 Centralized training + +To simulate a centralized training baseline, we run FedAvg algorithm +with 1 client for 25 rounds, where each round consists of one single epoch. + +``` +python ./tf_fl_script_executor_cifar10.py \ + --algo centralized \ + --n_clients 1 \ + --num_rounds 25 \ + --batch_size 64 \ + --epochs 1 \ + --alpha 0.0 +``` +Note, here `--alpha 0.0` is a placeholder value used to disable data +splits for centralized training. + +### 2.2 FedAvg with different data heterogeneity (alpha values) + +Here we run FedAvg for 50 rounds, each round with 4 local epochs. This +corresponds roughly to the same number of iterations across clients as +in the centralized baseline above (50*4 divided by 8 clients is 25): +``` +for alpha in 1.0 0.5 0.3 0.1; do + + python ./tf_fl_script_executor_cifar10.py \ + --algo fedavg \ + --n_clients 8 \ + --num_rounds 50 \ + --batch_size 64 \ + --epochs 4 \ + --alpha $alpha + +done +``` + +### 2.3 Advanced FL algorithms (FedOpt) + +Next, let's try some different FL algorithms on a more heterogeneous split: + +[FedOpt](https://arxiv.org/abs/2003.00295) uses optimizers on server +side to update the global model from client-side gradients. Here we +use SGD with momentum and cosine learning rate decay: +``` +python ./tf_fl_script_executor_cifar10.py \ + --algo fedopt \ + --n_clients 8 \ + --num_rounds 50 \ + --batch_size 64 \ + --epochs 4 \ + --alpha 0.1 +``` + + +## 3. Results + +Now let's compare experimental results. + +### 3.1 Centralized training vs. FedAvg for homogeneous split +Let's first compare FedAvg with homogeneous data split +(i.e. `alpha=1.0`) and centralized training. As can be seen from the +figure and table below, FedAvg can achieve similar performance to +centralized training under homogeneous data split, i.e., when there is +no difference in data distributions among different clients. + +| Config | Alpha | Val score | +|-----------------|-------|-----------| +| cifar10_central | n.a. | 0.8758 | +| cifar10_fedavg | 1.0 | 0.8839 | + + + +### 3.2 Impact of client data heterogeneity + +Here we compare the impact of data heterogeneity by varying the +`alpha` value, where lower values cause higher heterogeneity. As can +be observed in the table below, performance of the FedAvg decreases +as data heterogeneity becomes higher. + +| Config | Alpha | Val score | +| ----------- | ----------- | ----------- | +| cifar10_fedavg | 1.0 | 0.8838 | +| cifar10_fedavg | 0.5 | 0.8685 | +| cifar10_fedavg | 0.3 | 0.8323 | +| cifar10_fedavg | 0.1 | 0.7903 | + + + +### 3.3 Impact of different FL algorithms + +Lastly, we compare the performance of different FL algorithms, with +`alpha` value fixed to 0.1, i.e., a high client data heterogeneity. We +can observe from the figure below that, FedOpt achieves better +performance, with better convergence rates compared to FedAvg with the +same alpha setting. + +| Config | Alpha | Val score | +| ----------- | ----------- | ----------- | +| cifar10_fedavg | 0.1 | 0.7903 | +| cifar10_fedopt | 0.1 | 0.8145 | + + + +> [!NOTE] +> More examples can be found at https://nvidia.github.io/NVFlare. diff --git a/examples/getting_started/tf/figs/fedavg-diff-algos.png b/examples/getting_started/tf/figs/fedavg-diff-algos.png new file mode 100755 index 0000000000..ba58185859 Binary files /dev/null and b/examples/getting_started/tf/figs/fedavg-diff-algos.png differ diff --git a/examples/getting_started/tf/figs/fedavg-diff-alphas.png b/examples/getting_started/tf/figs/fedavg-diff-alphas.png new file mode 100755 index 0000000000..24cebdaee5 Binary files /dev/null and b/examples/getting_started/tf/figs/fedavg-diff-alphas.png differ diff --git a/examples/getting_started/tf/figs/fedavg-vs-centralized.png b/examples/getting_started/tf/figs/fedavg-vs-centralized.png new file mode 100755 index 0000000000..295645c675 Binary files /dev/null and b/examples/getting_started/tf/figs/fedavg-vs-centralized.png differ diff --git a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb index af87d0af9b..17374b9356 100644 --- a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb +++ b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb @@ -55,7 +55,7 @@ "outputs": [], "source": [ "! pip install --ignore-installed blinker\n", - "! pip install nvflare~=2.5.0rc tensorflow" + "! pip install -r ./requirements.txt" ] }, { @@ -410,6 +410,16 @@ "source": [ "! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_tf_fedavg" ] + }, + { + "cell_type": "markdown", + "id": "387662f4-7d05-4840-bcc7-a2523e03c2c2", + "metadata": {}, + "source": [ + "#### 8. Next steps\n", + "\n", + "Continue with the steps described in the [README.md](README.md) to run more experiments with a more complex model and more advanced FL algorithms. " + ] } ], "metadata": { @@ -428,7 +438,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/getting_started/tf/run_jobs.sh b/examples/getting_started/tf/run_jobs.sh new file mode 100755 index 0000000000..7195e71474 --- /dev/null +++ b/examples/getting_started/tf/run_jobs.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +export TF_FORCE_GPU_ALLOW_GROWTH=true +export TF_GPU_ALLOCATOR=cuda_malloc_asyncp + + +# You can change GPU index if multiple GPUs are available +GPU_INDX=0 + +# You can change workspace - where results and artefact will be saved. +WORKSPACE=/tmp + +# Run centralized training job +python ./tf_fl_script_executor_cifar10.py \ + --algo centralized \ + --n_clients 1 \ + --num_rounds 25 \ + --batch_size 64 \ + --epochs 1 \ + --alpha 0.0 \ + --gpu $GPU_INDX \ + --workspace $WORKSPACE + + +# Run FedAvg with different alpha values +for alpha in 1.0 0.5 0.3 0.1; do + + python ./tf_fl_script_executor_cifar10.py \ + --algo fedavg \ + --n_clients 8 \ + --num_rounds 50 \ + --batch_size 64 \ + --epochs 4 \ + --alpha $alpha \ + --gpu $GPU_INDX \ + --workspace $WORKSPACE + +done + + +# Run FedOpt job +python ./tf_fl_script_executor_cifar10.py \ + --algo fedopt \ + --n_clients 8 \ + --num_rounds 50 \ + --batch_size 64 \ + --epochs 4 \ + --alpha 0.1 \ + --gpu $GPU_INDX \ + --workspace $WORKSPACE diff --git a/examples/getting_started/tf/src/cifar10_data_split.py b/examples/getting_started/tf/src/cifar10_data_split.py new file mode 100644 index 0000000000..1dd05f6385 --- /dev/null +++ b/examples/getting_started/tf/src/cifar10_data_split.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This Dirichlet sampling strategy for creating a heterogeneous partition is adopted +# from FedMA (https://github.com/IBM/FedMA). + +# MIT License + +# Copyright (c) 2020 International Business Machines + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import json +import os + +import numpy as np +from tensorflow.keras import datasets + + +def cifar10_split(split_dir: str = None, num_sites: int = 8, alpha: float = 0.5, seed: int = 0): + if split_dir is None: + raise ValueError("You need to define a valid `split_dir` for splitting the data.") + if not os.path.isabs(split_dir): + raise ValueError("`split_dir` needs to be absolute path.") + if alpha < 0.0: + raise ValueError(f"Alpha should be larger or equal 0.0 but was" f" {alpha}!") + + np.random.seed(seed) + + train_idx_paths = [] + + print(f"Partition CIFAR-10 dataset into {num_sites} sites with Dirichlet sampling under alpha {alpha}") + site_idx, class_sum = _partition_data(num_sites, alpha) + + # write to files + if not os.path.isdir(split_dir): + os.makedirs(split_dir) + sum_file_name = os.path.join(split_dir, "summary.txt") + with open(sum_file_name, "w") as sum_file: + sum_file.write(f"Number of clients: {num_sites} \n") + sum_file.write(f"Dirichlet sampling parameter: {alpha} \n") + sum_file.write("Class counts for each client: \n") + sum_file.write(json.dumps(class_sum)) + + site_file_path = os.path.join(split_dir, "site-") + for site in range(num_sites): + site_file_name = site_file_path + str(site + 1) + ".npy" + print(f"Save split index {site+1} of {num_sites} to {site_file_name}") + np.save(site_file_name, np.array(site_idx[site])) + train_idx_paths.append(site_file_name) + + return train_idx_paths + + +def _get_site_class_summary(train_label, site_idx): + class_sum = {} + + for site, data_idx in site_idx.items(): + unq, unq_cnt = np.unique(train_label[data_idx], return_counts=True) + tmp = {int(unq[i]): int(unq_cnt[i]) for i in range(len(unq))} + class_sum[site] = tmp + return class_sum + + +def _partition_data(num_sites, alpha): + # only training label is needed for doing split + (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() + + min_size = 0 + K = 10 + N = train_labels.shape[0] + site_idx = {} + + # split + while min_size < 10: + idx_batch = [[] for _ in range(num_sites)] + # for each class in the dataset + for k in range(K): + idx_k = np.where(train_labels == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(alpha, num_sites)) + # Balance + proportions = np.array([p * (len(idx_j) < N / num_sites) for p, idx_j in zip(proportions, idx_batch)]) + proportions = proportions / proportions.sum() + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] + min_size = min([len(idx_j) for idx_j in idx_batch]) + + # shuffle + for j in range(num_sites): + np.random.shuffle(idx_batch[j]) + site_idx[j] = idx_batch[j] + + # collect class summary + class_sum = _get_site_class_summary(train_labels, site_idx) + + return site_idx, class_sum diff --git a/examples/getting_started/tf/src/cifar10_tf_fl.py b/examples/getting_started/tf/src/cifar10_tf_fl.py index 15512a2eb8..5058a4025b 100644 --- a/examples/getting_started/tf/src/cifar10_tf_fl.py +++ b/examples/getting_started/tf/src/cifar10_tf_fl.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import tensorflow as tf from tensorflow.keras import datasets from tf_net import TFNet @@ -23,6 +24,9 @@ def main(): + # (2) initializes NVFlare client API + flare.init() + (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() # Normalize pixel values to be between 0 and 1 @@ -35,9 +39,6 @@ def main(): ) model.summary() - # (2) initializes NVFlare client API - flare.init() - # (3) gets FLModel from NVFlare while flare.is_running(): input_model = flare.receive() @@ -54,7 +55,7 @@ def main(): # (5) evaluate aggregated/received model _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2) print( - f"Accuracy of the received model on round {input_model.current_round} on the 10000 test images: {test_global_acc * 100} %" + f"Accuracy of the received model on round {input_model.current_round} on the {len(test_images)} test images: {test_global_acc * 100} %" ) model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels)) @@ -64,7 +65,7 @@ def main(): model.save_weights(PATH) _, test_acc = model.evaluate(test_images, test_labels, verbose=2) - print(f"Accuracy of the model on the 10000 test images: {test_acc * 100} %") + print(f"Accuracy of the model on the {len(test_images)} test images: {test_acc * 100} %") # (6) construct trained FL model (A dict of {layer name: layer weights} from the keras model) output_model = flare.FLModel( diff --git a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py new file mode 100644 index 0000000000..fbad375336 --- /dev/null +++ b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse + +import numpy as np +import tensorflow as tf +from tensorflow.keras import datasets, losses +from tf_net import ModerateTFNet + +# (1) import nvflare client API +import nvflare.client as flare + +PATH = "./tf_model.weights.h5" + + +def preprocess_dataset(dataset, is_training, batch_size=1): + """ + Apply pre-processing transformations to CIFAR10 dataset. + + Same pre-processings are used as in the Pytorch tutorial + on CIFAR10: https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-sim + + Training time pre-processings are (in-order): + - Image padding with 4 pixels in "reflect" mode on each side + - RandomCrop of 32 x 32 images + - RandomHorizontalFlip + - Normalize to [0, 1]: dividing pixels values by given CIFAR10 data mean & std + - Random shuffle + + Testing/Validation time pre-processings are: + - Normalize: dividing pixels values by 255 + + Args + ---------- + dataset: tf.data.Datset + Tensorflow Dataset + + is_training: bool + Boolean flag indicating if current phase is training phase. + + batch_size: int + Batch size + + Returns + ---------- + tf.data.Dataset + Tensorflow Dataset with pre-processings applied. + + """ + # Values from: https://github.com/NVIDIA/NVFlare/blob/main/examples/advanced/cifar10/pt/learners/cifar10_model_learner.py#L147 + mean_cifar10 = tf.constant([125.3, 123.0, 113.9], dtype=tf.float32) + std_cifar10 = tf.constant([63.0, 62.1, 66.7], dtype=tf.float32) + + if is_training: + + # Padding each dimension by 4 pixels each side + dataset = dataset.map( + lambda image, label: ( + tf.stack( + [ + tf.pad(tf.squeeze(t, [2]), [[4, 4], [4, 4]], mode="REFLECT") + for t in tf.split(image, num_or_size_splits=3, axis=2) + ], + axis=2, + ), + label, + ) + ) + # Random crop of 32 x 32 x 3 + dataset = dataset.map(lambda image, label: (tf.image.random_crop(image, size=(32, 32, 3)), label)) + # Random horizontal flip + dataset = dataset.map(lambda image, label: (tf.image.random_flip_left_right(image), label)) + # Normalize by dividing by given mean & std + dataset = dataset.map(lambda image, label: ((tf.cast(image, tf.float32) - mean_cifar10) / std_cifar10, label)) + # Random shuffle + dataset = dataset.shuffle(len(dataset), reshuffle_each_iteration=True) + # Convert to batches. + return dataset.batch(batch_size) + + else: + + # For validation / test only do normalization. + return dataset.map( + lambda image, label: ((tf.cast(image, tf.float32) - mean_cifar10) / std_cifar10, label) + ).batch(batch_size) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, required=True) + parser.add_argument("--epochs", type=int, required=True) + parser.add_argument("--train_idx_path", type=str, required=True) + args = parser.parse_args() + + (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() + + # Use alpha-split per-site data to simulate data heteogeniety, + # only if if train_idx_path is not None. + # + if args.train_idx_path != "None": + + print(f"Loading train indices from {args.train_idx_path}") + train_idx = np.load(args.train_idx_path) + train_images = train_images[train_idx] + train_labels = train_labels[train_idx] + + unq, unq_cnt = np.unique(train_labels, return_counts=True) + print( + ( + f"Loaded {len(train_idx)} training indices from {args.train_idx_path} " + "with label distribution:\nUnique labels: {unq}\nUnique Counts: {unq_cnt}" + ) + ) + + # Convert training & testing data to datasets + train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) + test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)) + + # Preprocessing + train_ds = preprocess_dataset(train_ds, is_training=True, batch_size=args.batch_size) + test_ds = preprocess_dataset(test_ds, is_training=False, batch_size=args.batch_size) + + model = ModerateTFNet() + model.build(input_shape=(None, 32, 32, 3)) + + # Tensorboard logs for each local training epoch + callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs/epochs", write_graph=False)] + # Tensorboard logs for each aggregation run + tf_summary_writer = tf.summary.create_file_writer(logdir="./logs/rounds") + + # Define loss function. + loss = losses.SparseCategoricalCrossentropy(from_logits=True) + + model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9), loss=loss, metrics=["accuracy"]) + model.summary() + + # (2) initializes NVFlare client API + flare.init() + + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (optional) print system info + system_info = flare.system_info() + print(f"NVFlare system info: {system_info}") + + # (4) loads model from NVFlare + for k, v in input_model.params.items(): + model.get_layer(k).set_weights(v) + + # (5) evaluate aggregated/received model + _, test_global_acc = model.evaluate(x=test_ds, verbose=2) + + with tf_summary_writer.as_default(): + tf.summary.scalar("global_model_accuracy", test_global_acc, input_model.current_round) + print( + f"Accuracy of the received model on round {input_model.current_round} on the {len(test_images)} test images: {test_global_acc * 100} %" + ) + + start_epoch = args.epochs * input_model.current_round + end_epoch = start_epoch + args.epochs + + print(f"Train from epoch {start_epoch} to {end_epoch}") + model.fit( + x=train_ds, + epochs=end_epoch, + validation_data=test_ds, + callbacks=callbacks, + initial_epoch=start_epoch, + validation_freq=1, + ) + + print("Finished Training") + + model.save_weights(PATH) + + _, test_acc = model.evaluate(x=test_ds, verbose=2) + + with tf_summary_writer.as_default(): + tf.summary.scalar("local_model_accuracy", test_acc, input_model.current_round) + print(f"Accuracy of the model on the {len(test_images)} test images: {test_acc * 100} %") + + # (6) construct trained FL model (A dict of {layer name: layer weights} from the keras model) + output_model = flare.FLModel( + params={layer.name: layer.get_weights() for layer in model.layers}, metrics={"accuracy": test_global_acc} + ) + # (7) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/examples/getting_started/tf/src/tf_net.py b/examples/getting_started/tf/src/tf_net.py index 95bf8ab462..bdd8d29147 100644 --- a/examples/getting_started/tf/src/tf_net.py +++ b/examples/getting_started/tf/src/tf_net.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,3 +29,37 @@ def __init__(self, input_shape=(None, 32, 32, 3)): self.add(layers.Flatten()) self.add(layers.Dense(64, activation="relu")) self.add(layers.Dense(10)) + + +class ModerateTFNet(models.Sequential): + # Follow ModerateCNN architecture from cifar10_nets.py + def __init__(self, input_shape=(None, 32, 32, 3)): + super().__init__() + self._input_shape = input_shape + + # Do not specify input as we will use delayed built only during runtime of the model + # self.add(layers.Input(shape=(32, 32, 3))) + + # Conv Layer block 1 + self.add(layers.Conv2D(32, (3, 3), activation="relu", padding="same")) + self.add(layers.Conv2D(64, (3, 3), activation="relu", padding="same")) + self.add(layers.MaxPooling2D((2, 2))) + + # Conv Layer block 2 + self.add(layers.Conv2D(128, (3, 3), activation="relu", padding="same")) + self.add(layers.Conv2D(128, (3, 3), activation="relu", padding="same")) + self.add(layers.MaxPooling2D((2, 2))) + self.add(layers.Dropout(rate=0.05)) + + # Conv Layer block 3 + self.add(layers.Conv2D(256, (3, 3), activation="relu", padding="same")) + self.add(layers.Conv2D(256, (3, 3), activation="relu", padding="same")) + self.add(layers.MaxPooling2D((2, 2))) + self.add(layers.Flatten()) + + # FC Layer + self.add(layers.Dropout(rate=0.1)) + self.add(layers.Dense(512, activation="relu")) + self.add(layers.Dense(512, activation="relu")) + self.add(layers.Dropout(rate=0.1)) + self.add(layers.Dense(10)) diff --git a/examples/getting_started/tf/tf_fl_script_executor_cifar10.py b/examples/getting_started/tf/tf_fl_script_executor_cifar10.py new file mode 100644 index 0000000000..0add8dc257 --- /dev/null +++ b/examples/getting_started/tf/tf_fl_script_executor_cifar10.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import multiprocessing + +import tensorflow as tf +from src.cifar10_data_split import cifar10_split +from src.tf_net import ModerateTFNet + +from nvflare import FedJob, ScriptExecutor + +gpu_devices = tf.config.experimental.list_physical_devices("GPU") +for device in gpu_devices: + tf.config.experimental.set_memory_growth(device, True) + + +CENTRALIZED_ALGO = "centralized" +FEDAVG_ALGO = "fedavg" +FEDOPT_ALGO = "fedopt" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--algo", + type=str, + required=True, + ) + parser.add_argument( + "--n_clients", + type=int, + default=8, + ) + parser.add_argument( + "--num_rounds", + type=int, + default=50, + ) + parser.add_argument( + "--batch_size", + type=int, + default=64, + ) + parser.add_argument( + "--epochs", + type=int, + default=4, + ) + parser.add_argument( + "--alpha", + type=float, + default=1.0, + ) + parser.add_argument( + "--workspace", + type=str, + default="/tmp", + ) + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + args = parser.parse_args() + multiprocessing.set_start_method("spawn") + + supported_algos = (CENTRALIZED_ALGO, FEDAVG_ALGO, FEDOPT_ALGO) + + if args.algo not in supported_algos: + raise ValueError(f"--algo should be one of: {supported_algos}, got: {args.algo}") + + train_script = "src/cifar10_tf_fl_alpha_split.py" + train_split_root = ( + f"{args.workspace}/cifar10_splits/clients{args.n_clients}_alpha{args.alpha}" # avoid overwriting results + ) + + # Prepare data splits + if args.alpha > 0.0: + + # Do alpha splitting if alpha value > 0.0 + print(f"preparing CIFAR10 and doing alpha split with alpha = {args.alpha}") + train_idx_paths = cifar10_split(num_sites=args.n_clients, alpha=args.alpha, split_dir=train_split_root) + + print(train_idx_paths) + else: + train_idx_paths = [None for __ in range(args.n_clients)] + + # Define job + job = FedJob(name=f"cifar10_tf_{args.algo}_alpha{args.alpha}") + + # Define the controller workflow and send to server + controller = None + task_script_args = f"--batch_size {args.batch_size} --epochs {args.epochs}" + + if args.algo == FEDAVG_ALGO or args.algo == CENTRALIZED_ALGO: + from nvflare import FedAvg + + controller = FedAvg( + num_clients=args.n_clients, + num_rounds=args.num_rounds, + ) + + elif args.algo == FEDOPT_ALGO: + from nvflare.app_opt.tf.fedopt_ctl import FedOpt + + controller = FedOpt( + num_clients=args.n_clients, + num_rounds=args.num_rounds, + ) + + job.to(controller, "server") + + # Define the initial global model and send to server + job.to(ModerateTFNet(input_shape=(None, 32, 32, 3)), "server") + + # Add clients + for i, train_idx_path in enumerate(train_idx_paths): + curr_task_script_args = task_script_args + f" --train_idx_path {train_idx_path}" + executor = ScriptExecutor(task_script_path=train_script, task_script_args=curr_task_script_args) + job.to(executor, f"site-{i+1}", gpu=args.gpu) + + # Can export current job to folder. + # job.export_job(f"{args.workspace}/nvflare/jobs/job_config") + + # Here we launch the job using simulator. + job.simulator_run(f"{args.workspace}/nvflare/jobs/{job.name}") diff --git a/examples/hello-world/ml-to-fl/np/code/train_loop.py b/examples/hello-world/ml-to-fl/np/code/train_loop.py index 442aacf9bb..e809268a57 100755 --- a/examples/hello-world/ml-to-fl/np/code/train_loop.py +++ b/examples/hello-world/ml-to-fl/np/code/train_loop.py @@ -34,11 +34,11 @@ def main(): # get system information sys_info = flare.system_info() - print(f"system info is: {sys_info}") + print(f"system info is: {sys_info}", flush=True) while flare.is_running(): input_model = flare.receive() - print(f"received weights is: {input_model.params}") + print(f"received weights is: {input_model.params}", flush=True) input_numpy_array = input_model.params["numpy_key"] @@ -49,11 +49,11 @@ def main(): metrics = evaluate(input_numpy_array) sys_info = flare.system_info() - print(f"system info is: {sys_info}") - print(f"finish round: {input_model.current_round}") + print(f"system info is: {sys_info}", flush=True) + print(f"finish round: {input_model.current_round}", flush=True) # send back the model - print(f"send back: {output_numpy_array}") + print(f"send back: {output_numpy_array}", flush=True) flare.send( flare.FLModel( params={"numpy_key": output_numpy_array}, diff --git a/examples/tutorials/setup_poc.ipynb b/examples/tutorials/setup_poc.ipynb index d1a1aa0007..d07c579882 100644 --- a/examples/tutorials/setup_poc.ipynb +++ b/examples/tutorials/setup_poc.ipynb @@ -81,12 +81,12 @@ "If you prefer not to use environment variable, you can do the followings: \n", "\n", "```\n", - "! nvflare config -pw /tmp/nvflare/poc\n", + "! nvflare config -pw /tmp/nvflare/poc --job_templates_dir ../../job_templates\n", "\n", "```\n", "or \n", "```\n", - "! nvflare config -poc_workspace_dir /tmp/nvflare/poc\n", + "! nvflare config --poc_workspace_dir /tmp/nvflare/poc --job_templates_dir ../../job_templates\n", "```" ] }, @@ -99,7 +99,7 @@ }, "outputs": [], "source": [ - "! nvflare config -pw /tmp/nvflare/poc" + "! nvflare config -pw /tmp/nvflare/poc --job_templates_dir ../../job_templates" ] }, { diff --git a/integration/monai/examples/spleen_ct_segmentation_local/README.md b/integration/monai/examples/spleen_ct_segmentation_local/README.md index 42772aa9e9..44b28e2670 100644 --- a/integration/monai/examples/spleen_ct_segmentation_local/README.md +++ b/integration/monai/examples/spleen_ct_segmentation_local/README.md @@ -160,17 +160,17 @@ Experiment tracking for the FLARE-MONAI integration now uses `NVFlareStatsHandle In this example, the `spleen_ct_segmentation_local` job is configured to automatically log metrics to MLflow through the FL server. -- The `config_fed_client.json` contains the `NVFlareStatsHandler`, `MetricsSender`, and `MetricRelay` (with their respective pipes) to send the metrics to the server side as federated events. -- Then in `config_fed_server.json`, the `MLflowReceiver` is configured for the server to write the results to the default MLflow tracking server URI. +- The `config_fed_client.conf` contains the `NVFlareStatsHandler`, `MetricsSender`, and `MetricRelay` (with their respective pipes) to send the metrics to the server side as federated events. +- Then in `config_fed_server.conf`, the `MLflowReceiver` is configured for the server to write the results to the MLflow tracking server URI `http://127.0.0.1:5000`. -With this configuration the MLflow tracking server must be started before running the job: +We need to start MLflow tracking server before running this job: ``` mlflow server ``` > **_NOTE:_** The receiver on the server side can be easily configured to support other experiment tracking formats. - In addition to the `MLflowReceiver`, the `WandBReceiver` and `TBAnalyticsReceiver` can also be used in `config_fed_server.json` for Tensorboard and Weights & Biases experiment tracking streaming to the server. + In addition to the `MLflowReceiver`, the `WandBReceiver` and `TBAnalyticsReceiver` can also be used in `config_fed_server.conf` for Tensorboard and Weights & Biases experiment tracking streaming to the server. Next, we can submit the job. @@ -219,10 +219,16 @@ nvflare job submit -j jobs/spleen_ct_segementation_he ### 5.4 MLflow experiment tracking results -To view the results, you can access the MLflow dashboard in your browser using the default tracking uri `http://127.0.0.1:5000`. - -> **_NOTE:_** To write the results to the server workspace instead of using the MLflow server, users can remove the `tracking_uri` argument from the `MLflowReceiver` configuration and instead view the results by running `mlflow ui --port 5000` in the directory that contains the `mlruns/` directory in the server workspace. +To view the results, you can access the MLflow dashboard in your browser using the tracking uri `http://127.0.0.1:5000`. Once the training is started, you can see the experiment curves for the local clients in the current run on the MLflow dashboard. - \ No newline at end of file + + + +> **_NOTE:_** If you prefer not to start the MLflow server before federated training, +> you can alternatively choose to write the metrics streaming results to the server's +> job workspace directory. Remove the tracking_uri argument from the MLflowReceiver +> configuration. After the job finishes, download the server job workspace and unzip it. +> You can view the results by running mlflow ui --port 5000 in the directory containing +> the mlruns/ directory within the server job workspace. diff --git a/integration/monai/setup.py b/integration/monai/setup.py index a31011b3eb..8b41d0a099 100644 --- a/integration/monai/setup.py +++ b/integration/monai/setup.py @@ -24,14 +24,14 @@ release = os.environ.get("MONAI_NVFL_RELEASE") if release == "1": package_name = "monai-nvflare" - version = "0.2.4" + version = "0.2.9" else: package_name = "monai-nvflare-nightly" today = datetime.date.today().timetuple() year = today[0] % 1000 month = today[1] day = today[2] - version = f"0.2.3.{year:02d}{month:02d}{day:02d}" + version = f"0.2.9.{year:02d}{month:02d}{day:02d}" setup( name=package_name, @@ -57,5 +57,5 @@ long_description=long_description, long_description_content_type="text/markdown", python_requires=">=3.8,<3.11", - install_requires=["monai>=1.3.0", "nvflare==2.4.0rc6"], + install_requires=["monai>=1.3.1", "nvflare~=2.5.0rc1"], ) diff --git a/integration/xgboost/encryption_plugins/.editorconfig b/integration/xgboost/encryption_plugins/.editorconfig new file mode 100644 index 0000000000..97a7bc133a --- /dev/null +++ b/integration/xgboost/encryption_plugins/.editorconfig @@ -0,0 +1,11 @@ +root = true + +[*] +charset=utf-8 +indent_style = space +indent_size = 2 +insert_final_newline = true + +[*.py] +indent_style = space +indent_size = 4 diff --git a/integration/xgboost/encryption_plugins/CMakeLists.txt b/integration/xgboost/encryption_plugins/CMakeLists.txt new file mode 100644 index 0000000000..f5d71dd61c --- /dev/null +++ b/integration/xgboost/encryption_plugins/CMakeLists.txt @@ -0,0 +1,41 @@ +cmake_minimum_required(VERSION 3.19) +project(xgb_nvflare LANGUAGES CXX C VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Debug) + +option(GOOGLE_TEST "Build google tests" OFF) + +file(GLOB_RECURSE LIB_SRC "src/*.cc") + +add_library(nvflare SHARED ${LIB_SRC}) +set_target_properties(nvflare PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + ENABLE_EXPORTS ON +) +target_include_directories(nvflare PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include) + +if (APPLE) + add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o") + add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache") +endif () + +#-- Unit Tests +if(GOOGLE_TEST) + find_package(GTest REQUIRED) + enable_testing() + add_executable(nvflare_test) + target_link_libraries(nvflare_test PRIVATE nvflare) + + + target_include_directories(nvflare_test PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include) + + add_subdirectory(${xgb_nvflare_SOURCE_DIR}/tests) + + add_test( + NAME TestNvflarePlugins + COMMAND nvflare_test + WORKING_DIRECTORY ${xgb_nvflare_BINARY_DIR}) + +endif() diff --git a/integration/xgboost/encryption_plugins/README.md b/integration/xgboost/encryption_plugins/README.md new file mode 100644 index 0000000000..57f2c4621e --- /dev/null +++ b/integration/xgboost/encryption_plugins/README.md @@ -0,0 +1,9 @@ +# Build Instruction + +cd NVFlare/integration/xgboost/encryption_plugins +mkdir build +cd build +cmake .. +make + +The library is libxgb_nvflare.so diff --git a/integration/xgboost/processor/src/README.md b/integration/xgboost/encryption_plugins/src/README.md similarity index 100% rename from integration/xgboost/processor/src/README.md rename to integration/xgboost/encryption_plugins/src/README.md diff --git a/integration/xgboost/processor/src/dam/README.md b/integration/xgboost/encryption_plugins/src/dam/README.md similarity index 100% rename from integration/xgboost/processor/src/dam/README.md rename to integration/xgboost/encryption_plugins/src/dam/README.md diff --git a/integration/xgboost/encryption_plugins/src/dam/dam.cc b/integration/xgboost/encryption_plugins/src/dam/dam.cc new file mode 100644 index 0000000000..9fdb7d8582 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/dam/dam.cc @@ -0,0 +1,274 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <iostream> +#include <cstring> +#include "dam.h" + + +void print_hex(const uint8_t *buffer, std::size_t size) { + std::cout << std::hex; + for (int i = 0; i < size; i++) { + int c = buffer[i]; + std::cout << c << " "; + } + std::cout << std::endl << std::dec; +} + +void print_buffer(const uint8_t *buffer, std::size_t size) { + if (size <= 64) { + std::cout << "Whole buffer: " << size << " bytes" << std::endl; + print_hex(buffer, size); + return; + } + + std::cout << "First chunk, Total: " << size << " bytes" << std::endl; + print_hex(buffer, 32); + std::cout << "Last chunk, Offset: " << size-16 << " bytes" << std::endl; + print_hex(buffer+size-32, 32); +} + +size_t align(const size_t length) { + return ((length + 7)/8)*8; +} + +// DamEncoder ====== +void DamEncoder::AddBuffer(const Buffer &buffer) { + if (debug_) { + std::cout << "AddBuffer called, size: " << buffer.buf_size << std::endl; + } + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(buffer, buf_size); + entries_.emplace_back(kDataTypeBuffer, static_cast<const uint8_t *>(buffer.buffer), buffer.buf_size); +} + +void DamEncoder::AddFloatArray(const std::vector<double> &value) { + if (debug_) { + std::cout << "AddFloatArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8); + entries_.emplace_back(kDataTypeFloatArray, reinterpret_cast<const uint8_t *>(value.data()), value.size()); +} + +void DamEncoder::AddIntArray(const std::vector<int64_t> &value) { + if (debug_) { + std::cout << "AddIntArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(buffer, buf_size); + entries_.emplace_back(kDataTypeIntArray, reinterpret_cast<const uint8_t *>(value.data()), value.size()); +} + +void DamEncoder::AddBufferArray(const std::vector<Buffer> &value) { + if (debug_) { + std::cout << "AddBufferArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + size_t size = 0; + for (auto &buf: value) { + size += buf.buf_size; + } + size += 8*value.size(); + entries_.emplace_back(kDataTypeBufferArray, reinterpret_cast<const uint8_t *>(&value), size); +} + + +std::uint8_t * DamEncoder::Finish(size_t &size) { + encoded_ = true; + + size = CalculateSize(); + auto buf = static_cast<uint8_t *>(calloc(size, 1)); + auto pointer = buf; + auto sig = local_version_ ? kSignatureLocal : kSignature; + memcpy(pointer, sig, strlen(sig)); + memcpy(pointer+8, &size, 8); + memcpy(pointer+16, &data_set_id_, 8); + + pointer += kPrefixLen; + for (auto& entry : entries_) { + std::size_t len; + if (entry.data_type == kDataTypeBufferArray) { + auto buffers = reinterpret_cast<const std::vector<Buffer> *>(entry.pointer); + memcpy(pointer, &entry.data_type, 8); + pointer += 8; + auto array_size = static_cast<int64_t>(buffers->size()); + memcpy(pointer, &array_size, 8); + pointer += 8; + auto sizes = reinterpret_cast<int64_t *>(pointer); + for (auto &item : *buffers) { + *sizes = static_cast<int64_t>(item.buf_size); + sizes++; + } + len = 8*buffers->size(); + auto buf_ptr = pointer + len; + for (auto &item : *buffers) { + if (item.buf_size > 0) { + memcpy(buf_ptr, item.buffer, item.buf_size); + } + buf_ptr += item.buf_size; + len += item.buf_size; + } + } else { + memcpy(pointer, &entry.data_type, 8); + pointer += 8; + memcpy(pointer, &entry.size, 8); + pointer += 8; + len = entry.size * entry.ItemSize(); + if (len) { + memcpy(pointer, entry.pointer, len); + } + } + pointer += align(len); + } + + if ((pointer - buf) != size) { + std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl; + return nullptr; + } + + return buf; +} + +std::size_t DamEncoder::CalculateSize() { + std::size_t size = kPrefixLen; + + for (auto& entry : entries_) { + size += 16; // The Type and Len + auto len = entry.size * entry.ItemSize(); + size += align(len); + } + + return size; +} + + +// DamDecoder ====== + +DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version, bool debug) { + local_version_ = local_version; + buffer_ = buffer; + buf_size_ = size; + pos_ = buffer + kPrefixLen; + debug_ = debug; + + if (size >= kPrefixLen) { + memcpy(&len_, buffer + 8, 8); + memcpy(&data_set_id_, buffer + 16, 8); + } else { + len_ = 0; + data_set_id_ = 0; + } +} + +bool DamDecoder::IsValid() const { + auto sig = local_version_ ? kSignatureLocal : kSignature; + return buf_size_ >= kPrefixLen && memcmp(buffer_, sig, strlen(sig)) == 0; +} + +Buffer DamDecoder::DecodeBuffer() { + auto type = *reinterpret_cast<int64_t *>(pos_); + if (type != kDataTypeBuffer) { + std::cout << "Data type " << type << " doesn't match bytes" << std::endl; + return {}; + } + pos_ += 8; + + auto size = *reinterpret_cast<int64_t *>(pos_); + pos_ += 8; + + if (size == 0) { + return {}; + } + + auto ptr = reinterpret_cast<void *>(pos_); + pos_ += align(size); + return{ ptr, static_cast<std::size_t>(size)}; +} + +std::vector<int64_t> DamDecoder::DecodeIntArray() { + auto type = *reinterpret_cast<int64_t *>(pos_); + if (type != kDataTypeIntArray) { + std::cout << "Data type " << type << " doesn't match Int Array" << std::endl; + return {}; + } + pos_ += 8; + + auto array_size = *reinterpret_cast<int64_t *>(pos_); + pos_ += 8; + auto ptr = reinterpret_cast<int64_t *>(pos_); + pos_ += align(8 * array_size); + return {ptr, ptr + array_size}; +} + +std::vector<double> DamDecoder::DecodeFloatArray() { + auto type = *reinterpret_cast<int64_t *>(pos_); + if (type != kDataTypeFloatArray) { + std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; + return {}; + } + pos_ += 8; + + auto array_size = *reinterpret_cast<int64_t *>(pos_); + pos_ += 8; + + auto ptr = reinterpret_cast<double *>(pos_); + pos_ += align(8 * array_size); + return {ptr, ptr + array_size}; +} + +std::vector<Buffer> DamDecoder::DecodeBufferArray() { + auto type = *reinterpret_cast<int64_t *>(pos_); + if (type != kDataTypeBufferArray) { + std::cout << "Data type " << type << " doesn't match Bytes Array" << std::endl; + return {}; + } + pos_ += 8; + + auto num = *reinterpret_cast<int64_t *>(pos_); + pos_ += 8; + + auto size_ptr = reinterpret_cast<int64_t *>(pos_); + auto buf_ptr = pos_ + 8 * num; + size_t total_size = 8 * num; + auto result = std::vector<Buffer>(num); + for (int i = 0; i < num; i++) { + auto size = size_ptr[i]; + if (buf_size_ > 0) { + result[i].buf_size = size; + result[i].buffer = buf_ptr; + buf_ptr += size; + } + total_size += size; + } + + pos_ += align(total_size); + return result; +} diff --git a/integration/xgboost/encryption_plugins/src/include/base_plugin.h b/integration/xgboost/encryption_plugins/src/include/base_plugin.h new file mode 100644 index 0000000000..dddd5a7911 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/base_plugin.h @@ -0,0 +1,155 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <cstdint> // for uint8_t, uint32_t, int32_t, int64_t +#include <string_view> // for string_view +#include <utility> // for pair +#include <vector> // for vector +#include <sstream> +#include <iomanip> +#include <unistd.h> + +#include "util.h" + +namespace nvflare { + +/** + * @brief Abstract interface for the encryption plugin + * + * All plugin implementations must inherit this class. + */ +class BasePlugin { +protected: + bool debug_ = false; + bool print_timing_ = false; + bool dam_debug_ = false; + +public: +/** + * @brief Constructor + * + * All inherited classes should call this constructor. + * + * @param args Entries from federated_plugin in communicator environments. + */ + explicit BasePlugin( + std::vector<std::pair<std::string_view, std::string_view>> const &args) { + debug_ = get_bool(args, "debug"); + print_timing_ = get_bool(args, "print_timing"); + dam_debug_ = get_bool(args, "dam_debug"); + } + + /** + * @brief Destructor + */ + virtual ~BasePlugin() = default; + + /** + * @brief Identity for the plugin used for debug + * + * This is a string with instance address and process id. + */ + std::string Ident() { + std::stringstream ss; + ss << std::hex << std::uppercase << std::setw(sizeof(void*) * 2) << std::setfill('0') << + reinterpret_cast<uintptr_t>(this); + return ss.str() + "-" + std::to_string(getpid()); + } + + /** + * @brief Encrypt the gradient pairs + * + * @param in_gpair Input g and h pairs for each record + * @param n_in The array size (2xnum_of_records) + * @param out_gpair Pointer to encrypted buffer + * @param n_out Encrypted buffer size + */ + virtual void EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, std::size_t *n_out) = 0; + + /** + * @brief Process encrypted gradient pairs + * + * @param in_gpair Encrypted gradient pairs + * @param n_bytes Buffer size of Encrypted gradient + * @param out_gpair Pointer to decrypted gradient pairs + * @param out_n_bytes Decrypted buffer size + */ + virtual void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) = 0; + + /** + * @brief Reset the histogram context + * + * @param cutptrs Cut-pointers for the flattened histograms + * @param cutptr_len cutptrs array size (number of features plus one) + * @param bin_idx An array (flattened matrix) of slot index for each record/feature + * @param n_idx The size of above array + */ + virtual void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, + std::int32_t const *bin_idx, std::size_t n_idx) = 0; + + /** + * @brief Encrypt histograms for horizontal training + * + * @param in_histogram The array for the histogram + * @param len The array size + * @param out_hist Pointer to encrypted buffer + * @param out_len Encrypted buffer size + */ + virtual void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Process encrypted histograms for horizontal training + * + * @param buffer Buffer for encrypted histograms + * @param len Buffer size of encrypted histograms + * @param out_hist Pointer to decrypted histograms + * @param out_len Size of above array + */ + virtual void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, + double **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Build histograms in encrypted space for vertical training + * + * @param ridx Pointer to a matrix of row IDs for each node + * @param sizes An array of sizes of each node + * @param nidx An array for each node ID + * @param len Number of nodes + * @param out_hist Pointer to encrypted histogram buffer + * @param out_len Buffer size + */ + virtual void BuildEncryptedHistVert(std::uint64_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Decrypt histogram for vertical training + * + * @param hist_buffer Encrypted histogram buffer + * @param len Buffer size of encrypted histogram + * @param out Pointer to decrypted histograms + * @param out_len Size of above array + */ + virtual void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len) = 0; +}; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/dam.h b/integration/xgboost/encryption_plugins/src/include/dam.h new file mode 100644 index 0000000000..8677a413b1 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/dam.h @@ -0,0 +1,143 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include <vector> +#include <map> + +constexpr char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 +constexpr char kSignatureLocal[] = "NVDADAML"; // DAM Local version +constexpr int kPrefixLen = 24; + +constexpr int kDataTypeInt = 1; +constexpr int kDataTypeFloat = 2; +constexpr int kDataTypeString = 3; +constexpr int kDataTypeBuffer = 4; +constexpr int kDataTypeIntArray = 257; +constexpr int kDataTypeFloatArray = 258; +constexpr int kDataTypeBufferArray = 259; +constexpr int kDataTypeMap = 1025; + +/*! \brief A replacement for std::span */ +class Buffer { +public: + void *buffer; + size_t buf_size; + bool allocated; + + Buffer() : buffer(nullptr), buf_size(0), allocated(false) { + } + + Buffer(void *buffer, size_t buf_size, bool allocated=false) : + buffer(buffer), buf_size(buf_size), allocated(allocated) { + } + + Buffer(const Buffer &that): + buffer(that.buffer), buf_size(that.buf_size), allocated(false) { + } +}; + +class Entry { + public: + int64_t data_type; + const uint8_t * pointer; + int64_t size; + + Entry(int64_t data_type, const uint8_t *pointer, int64_t size) { + this->data_type = data_type; + this->pointer = pointer; + this->size = size; + } + + [[nodiscard]] std::size_t ItemSize() const + { + size_t item_size; + switch (data_type) { + case kDataTypeBuffer: + case kDataTypeString: + case kDataTypeBufferArray: + item_size = 1; + break; + default: + item_size = 8; + } + return item_size; + } +}; + +class DamEncoder { + private: + bool encoded_ = false; + bool local_version_ = false; + bool debug_ = false; + int64_t data_set_id_; + std::vector<Entry> entries_; + + public: + explicit DamEncoder(int64_t data_set_id, bool local_version=false, bool debug=false) { + data_set_id_ = data_set_id; + local_version_ = local_version; + debug_ = debug; + + } + + void AddBuffer(const Buffer &buffer); + + void AddIntArray(const std::vector<int64_t> &value); + + void AddFloatArray(const std::vector<double> &value); + + void AddBufferArray(const std::vector<Buffer> &value); + + std::uint8_t * Finish(size_t &size); + + private: + std::size_t CalculateSize(); +}; + +class DamDecoder { + private: + bool local_version_ = false; + std::uint8_t *buffer_ = nullptr; + std::size_t buf_size_ = 0; + std::uint8_t *pos_ = nullptr; + std::size_t remaining_ = 0; + int64_t data_set_id_ = 0; + int64_t len_ = 0; + bool debug_ = false; + + public: + explicit DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version=false, bool debug=false); + + [[nodiscard]] std::size_t Size() const { + return len_; + } + + [[nodiscard]] int64_t GetDataSetId() const { + return data_set_id_; + } + + [[nodiscard]] bool IsValid() const; + + Buffer DecodeBuffer(); + + std::vector<int64_t> DecodeIntArray(); + + std::vector<double> DecodeFloatArray(); + + std::vector<Buffer> DecodeBufferArray(); +}; + +void print_buffer(const uint8_t *buffer, std::size_t size); diff --git a/integration/xgboost/encryption_plugins/src/include/data_set_ids.h b/integration/xgboost/encryption_plugins/src/include/data_set_ids.h new file mode 100644 index 0000000000..98eb20e838 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/data_set_ids.h @@ -0,0 +1,23 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +constexpr int kDataSetGHPairs = 1; +constexpr int kDataSetAggregation = 2; +constexpr int kDataSetAggregationWithFeatures = 3; +constexpr int kDataSetAggregationResult = 4; +constexpr int kDataSetHistograms = 5; +constexpr int kDataSetHistogramResult = 6; diff --git a/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h b/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h new file mode 100644 index 0000000000..7b4f353b21 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h @@ -0,0 +1,66 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "base_plugin.h" + +namespace nvflare { + +// Plugin that delegates to other real plugins +class DelegatedPlugin : public BasePlugin { + + BasePlugin *plugin_{nullptr}; + +public: + explicit DelegatedPlugin(std::vector<std::pair<std::string_view, std::string_view>> const &args); + + ~DelegatedPlugin() override { + delete plugin_; + } + + void EncryptGPairs(const float* in_gpair, std::size_t n_in, std::uint8_t** out_gpair, std::size_t* n_out) override { + plugin_->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + } + + void SyncEncryptedGPairs(const std::uint8_t* in_gpair, std::size_t n_bytes, const std::uint8_t** out_gpair, + std::size_t* out_n_bytes) override { + plugin_->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, out_n_bytes); + } + + void ResetHistContext(const std::uint32_t* cutptrs, std::size_t cutptr_len, const std::int32_t* bin_idx, + std::size_t n_idx) override { + plugin_->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + } + + void BuildEncryptedHistHori(const double* in_histogram, std::size_t len, std::uint8_t** out_hist, + std::size_t* out_len) override { + plugin_->BuildEncryptedHistHori(in_histogram, len, out_hist, out_len); + } + + void SyncEncryptedHistHori(const std::uint8_t* buffer, std::size_t len, double** out_hist, + std::size_t* out_len) override { + plugin_->SyncEncryptedHistHori(buffer, len, out_hist, out_len); + } + + void BuildEncryptedHistVert(const std::uint64_t** ridx, const std::size_t* sizes, const std::int32_t* nidx, + std::size_t len, std::uint8_t** out_hist, std::size_t* out_len) override { + plugin_->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + } + + void SyncEncryptedHistVert(std::uint8_t* hist_buffer, std::size_t len, double** out, std::size_t* out_len) override { + plugin_->SyncEncryptedHistVert(hist_buffer, len, out, out_len); + } +}; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/local_plugin.h b/integration/xgboost/encryption_plugins/src/include/local_plugin.h new file mode 100644 index 0000000000..2022322266 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/local_plugin.h @@ -0,0 +1,107 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "base_plugin.h" +#include "dam.h" + +namespace nvflare { + +// A base plugin for all plugins that handle encryption locally in C++ +class LocalPlugin : public BasePlugin { +protected: + std::vector<double> gh_pairs_; + std::vector<uint8_t> encrypted_gh_; + std::vector<double> histo_; + std::vector<uint32_t> cuts_; + std::vector<int32_t> slots_; + std::vector<uint8_t> buffer_; + +public: + explicit LocalPlugin(std::vector<std::pair<std::string_view, std::string_view>> const &args) : + BasePlugin(args) {} + + ~LocalPlugin() override = default; + + void EncryptGPairs(const float *in_gpair, std::size_t n_in, std::uint8_t **out_gpair, + std::size_t *n_out) override; + + void SyncEncryptedGPairs(const std::uint8_t *in_gpair, std::size_t n_bytes, const std::uint8_t **out_gpair, + std::size_t *out_n_bytes) override; + + void ResetHistContext(const std::uint32_t *cutptrs, std::size_t cutptr_len, const std::int32_t *bin_idx, + std::size_t n_idx) override; + + void BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) override; + + void SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) override; + + void BuildEncryptedHistVert(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) override; + + void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, double **out, + std::size_t *out_len) override; + + // Method needs to be implemented by local plugins + + /*! + * \brief Encrypt a vector of float-pointing numbers + * \param cleartext A vector of numbers in cleartext + * \return A buffer with serialized ciphertext + */ + virtual Buffer EncryptVector(const std::vector<double> &cleartext) = 0; + + /*! + * \brief Decrypt a serialized ciphertext into an array of numbers + * \param ciphertext A serialzied buffer of ciphertext + * \return An array of numbers + */ + virtual std::vector<double> DecryptVector(const std::vector<Buffer> &ciphertext) = 0; + + /*! + * \brief Add the G&H pairs for a series of samples + * \param sample_ids A map of slot number and an array of sample IDs + * \return A map of the serialized encrypted sum of G and H for each slot + * The input and output maps must have the same size + */ + virtual std::map<int, Buffer> AddGHPairs(const std::map<int, std::vector<int>> &sample_ids) = 0; + + /*! + * \brief Free encrypted data buffer + * \param ciphertext The buffer for encrypted data + */ + virtual void FreeEncryptedData(Buffer &ciphertext) { + if (ciphertext.allocated && ciphertext.buffer != nullptr) { + free(ciphertext.buffer); + ciphertext.allocated = false; + } + ciphertext.buffer = nullptr; + ciphertext.buf_size = 0; + }; + +private: + + void BuildEncryptedHistVertActive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len); + + void BuildEncryptedHistVertPassive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len); + +}; + +} // namespace nvflare diff --git a/integration/xgboost/processor/src/include/nvflare_processor.h b/integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h similarity index 77% rename from integration/xgboost/processor/src/include/nvflare_processor.h rename to integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h index cb7076eaf4..87f47d622c 100644 --- a/integration/xgboost/processor/src/include/nvflare_processor.h +++ b/integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h @@ -14,61 +14,65 @@ * limitations under the License. */ #pragma once + #include <cstdint> // for uint8_t, uint32_t, int32_t, int64_t #include <string_view> // for string_view #include <utility> // for pair #include <vector> // for vector -const int kDataSetHGPairs = 1; -const int kDataSetAggregation = 2; -const int kDataSetAggregationWithFeatures = 3; -const int kDataSetAggregationResult = 4; -const int kDataSetHistograms = 5; -const int kDataSetHistogramResult = 6; - -// Opaque pointer type for the C API. -typedef void *FederatedPluginHandle; // NOLINT +#include "base_plugin.h" namespace nvflare { -// Plugin that uses Python tenseal and GRPC. -class TensealPlugin { + +// Plugin that uses Python TenSeal and GRPC. +class NvflarePlugin : public BasePlugin { // Buffer for storing encrypted gradient pairs. std::vector<std::uint8_t> encrypted_gpairs_; // Buffer for histogram cut pointers (indptr of a CSC). std::vector<std::uint32_t> cut_ptrs_; // Buffer for histogram index. std::vector<std::int32_t> bin_idx_; + std::vector<double> gh_pairs_; bool feature_sent_{false}; // The feature index. std::vector<std::int64_t> features_; // Buffer for output histogram. std::vector<std::uint8_t> encrypted_hist_; - std::vector<double> hist_; + // A temporary buffer to hold return value + std::vector<std::uint8_t> buffer_; + // Buffer for clear histogram + std::vector<double> histo_; public: - TensealPlugin( - std::vector<std::pair<std::string_view, std::string_view>> const &args); + explicit NvflarePlugin(std::vector<std::pair<std::string_view, std::string_view>> const &args) : BasePlugin(args) {} + + ~NvflarePlugin() override = default; + // Gradient pairs void EncryptGPairs(float const *in_gpair, std::size_t n_in, - std::uint8_t **out_gpair, std::size_t *n_out); + std::uint8_t **out_gpair, std::size_t *n_out) override; + void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, std::uint8_t const **out_gpair, - std::size_t *out_n_bytes); + std::size_t *out_n_bytes) override; // Histogram void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, - std::int32_t const *bin_idx, std::size_t n_idx); + std::int32_t const *bin_idx, std::size_t n_idx) override; + void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, - std::uint8_t **out_hist, std::size_t *out_len); + std::uint8_t **out_hist, std::size_t *out_len) override; + void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, - double **out_hist, std::size_t *out_len); + double **out_hist, std::size_t *out_len) override; - void BuildEncryptedHistVert(std::size_t const **ridx, + void BuildEncryptedHistVert(std::uint64_t const **ridx, std::size_t const *sizes, std::int32_t const *nidx, std::size_t len, - std::uint8_t **out_hist, std::size_t *out_len); + std::uint8_t **out_hist, std::size_t *out_len) override; + void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, - double **out, std::size_t *out_len); + double **out, std::size_t *out_len) override; }; } // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h b/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h new file mode 100644 index 0000000000..3abeee4b56 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "local_plugin.h" + +namespace nvflare { + // A pass-through plugin that doesn't encrypt any data + class PassThruPlugin : public LocalPlugin { + public: + explicit PassThruPlugin(std::vector<std::pair<std::string_view, std::string_view>> const &args) : + LocalPlugin(args) {} + + ~PassThruPlugin() override = default; + + // Horizontal in local plugin still goes through NVFlare, so it needs to be overwritten + void BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) override; + + void SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) override; + + Buffer EncryptVector(const std::vector<double> &cleartext) override; + + std::vector<double> DecryptVector(const std::vector<Buffer> &ciphertext) override; + + std::map<int, Buffer> AddGHPairs(const std::map<int, std::vector<int>> &sample_ids) override; + }; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/util.h b/integration/xgboost/encryption_plugins/src/include/util.h new file mode 100644 index 0000000000..bb8ba16d1a --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/util.h @@ -0,0 +1,18 @@ +#pragma once +#include <string> +#include <vector> + +std::vector<std::pair<int, int>> distribute_work(size_t num_jobs, size_t num_workers); + +uint32_t to_int(double d); + +double to_double(uint32_t i); + +std::string get_string(std::vector<std::pair<std::string_view, std::string_view>> const &args, + std::string_view const &key,std::string_view default_value = ""); + +bool get_bool(std::vector<std::pair<std::string_view, std::string_view>> const &args, + const std::string &key, bool default_value = false); + +int get_int(std::vector<std::pair<std::string_view, std::string_view>> const &args, + const std::string &key, int default_value = 0); diff --git a/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc new file mode 100644 index 0000000000..a026822799 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "delegated_plugin.h" +#include "pass_thru_plugin.h" +#include "nvflare_plugin.h" + +namespace nvflare { + +DelegatedPlugin::DelegatedPlugin(std::vector<std::pair<std::string_view, std::string_view>> const &args): + BasePlugin(args) { + + auto name = get_string(args, "name"); + // std::cout << "==== Name is " << name << std::endl; + if (name == "pass-thru") { + plugin_ = new PassThruPlugin(args); + } else if (name == "nvflare") { + plugin_ = new NvflarePlugin(args); + } else { + throw std::invalid_argument{"Unknown plugin name: " + name}; + } +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc new file mode 100644 index 0000000000..99e304ea77 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc @@ -0,0 +1,366 @@ +/** + * Copyright 2014-2024 by XGBoost Contributors + */ +#include <iostream> +#include <algorithm> +#include <chrono> +#include "local_plugin.h" +#include "data_set_ids.h" + +namespace nvflare { + +void LocalPlugin::EncryptGPairs(const float *in_gpair, std::size_t n_in, std::uint8_t **out_gpair, std::size_t *n_out) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::EncryptGPairs called with pairs size: " << n_in << std::endl; + } + + if (print_timing_) { + std::cout << "Encrypting " << n_in / 2 << " GH Pairs" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto pairs = std::vector<float>(in_gpair, in_gpair + n_in); + auto double_pairs = std::vector<double>(pairs.cbegin(), pairs.cend()); + auto encrypted_data = EncryptVector(double_pairs); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast<double>(std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()) / 1000.0; + std::cout << "Encryption time: " << secs << " seconds" << std::endl; + } + + // Serialize with DAM so the buffers can be separated after all-gather + DamEncoder encoder(kDataSetGHPairs, true, dam_debug_); + encoder.AddBuffer(encrypted_data); + + std::size_t size; + auto buffer = encoder.Finish(size); + FreeEncryptedData(encrypted_data); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_gpair = buffer_.data(); + *n_out = buffer_.size(); + if (debug_) { + std::cout << "Encrypted GPairs:" << std::endl; + print_buffer(*out_gpair, *n_out); + } + + // Save pairs for future operations. This is only called on active site + gh_pairs_ = std::vector<double>(double_pairs); +} + +void LocalPlugin::SyncEncryptedGPairs(const std::uint8_t *in_gpair, std::size_t n_bytes, + const std::uint8_t **out_gpair, std::size_t *out_n_bytes) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedGPairs called with buffer:" << std::endl; + print_buffer(in_gpair, n_bytes); + } + + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; + auto decoder = DamDecoder(const_cast<std::uint8_t *>(in_gpair), n_bytes, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "LocalPlugin::SyncEncryptedGPairs called with wrong data" << std::endl; + return; + } + + auto encrypted_buffer = decoder.DecodeBuffer(); + if (debug_) { + std::cout << "Encrypted buffer size: " << encrypted_buffer.buf_size << std::endl; + } + + // The caller may free buffer so a copy is needed + auto pointer = static_cast<u_int8_t *>(encrypted_buffer.buffer); + encrypted_gh_ = std::vector<std::uint8_t>(pointer, pointer + encrypted_buffer.buf_size); + FreeEncryptedData(encrypted_buffer); +} + +void LocalPlugin::ResetHistContext(const std::uint32_t *cutptrs, std::size_t cutptr_len, const std::int32_t *bin_idx, + std::size_t n_idx) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::ResetHistContext called with cutptrs size: " << cutptr_len << " bin_idx size: " + << n_idx << std::endl; + } + + cuts_ = std::vector<uint32_t>(cutptrs, cutptrs + cutptr_len); + slots_ = std::vector<int32_t>(bin_idx, bin_idx + n_idx); +} + +void LocalPlugin::BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistHori called with " << len << " entries" << std::endl; + print_buffer(reinterpret_cast<const uint8_t*>(in_histogram), len); + } + + // don't have a local implementation yet, just encoded it and let NVFlare handle it. + DamEncoder encoder(kDataSetHistograms, false, dam_debug_); + auto histograms = std::vector<double>(in_histogram, in_histogram + len); + encoder.AddFloatArray(histograms); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_hist = buffer_.data(); + *out_len = buffer_.size(); + if (debug_) { + std::cout << "Output buffer" << std::endl; + print_buffer(*out_hist, *out_len); + } +} + +void LocalPlugin::SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + print_buffer(buffer, len); + } + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector<double>& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast<std::uint8_t *>(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast<int>(pointer - buffer) << std::endl; + break; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + std::to_string(decoder.GetDataSetId())}; + } + + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); + + if (debug_) { + std::cout << "Output buffer" << std::endl; + print_buffer(reinterpret_cast<const uint8_t*>(*out_hist), histo_.size() * sizeof(double)); + } +} + +void LocalPlugin::BuildEncryptedHistVert(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVert called with number of nodes: " << len << std::endl; + } + + if (gh_pairs_.empty()) { + BuildEncryptedHistVertPassive(ridx, sizes, nidx, len, out_hist, out_len); + } else { + BuildEncryptedHistVertActive(ridx, sizes, nidx, len, out_hist, out_len); + } + + if (debug_) { + std::cout << "Encrypted histogram output:" << std::endl; + print_buffer(*out_hist, *out_len); + } +} + +void LocalPlugin::BuildEncryptedHistVertActive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVertActive called with " << len << " nodes" << std::endl; + } + + auto total_bin_size = cuts_.back(); + auto histo_size = total_bin_size * 2; + auto total_size = histo_size * len; + + histo_.clear(); + histo_.resize(total_size); + size_t start = 0; + for (std::size_t i = 0; i < len; i++) { + for (std::size_t j = 0; j < sizes[i]; j++) { + auto row_id = ridx[i][j]; + auto num = cuts_.size() - 1; + for (std::size_t f = 0; f < num; f++) { + int slot = slots_[f + num * row_id]; + if ((slot < 0) || (slot >= total_bin_size)) { + continue; + } + auto g = gh_pairs_[row_id * 2]; + auto h = gh_pairs_[row_id * 2 + 1]; + (histo_)[start + slot * 2] += g; + (histo_)[start + slot * 2 + 1] += h; + } + } + start += histo_size; + } + + // Histogram is in clear, can't send to all_gather. Just return empty DAM buffer + auto encoder = DamEncoder(kDataSetAggregationResult, true, dam_debug_); + encoder.AddBuffer(Buffer()); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = size; +} + +void LocalPlugin::BuildEncryptedHistVertPassive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVertPassive called with " << len << " nodes" << std::endl; + } + + auto num_slot = cuts_.back(); + auto total_size = num_slot * len; + + auto encrypted_histo = std::vector<Buffer>(total_size); + size_t offset = 0; + for (std::size_t i = 0; i < len; i++) { + auto num = cuts_.size() - 1; + auto row_id_map = std::map<int, std::vector<int>>(); + + // Empty slot leaks data so fill everything with empty vectors + for (int slot = 0; slot < num_slot; slot++) { + row_id_map.insert({slot, std::vector<int>()}); + } + + for (std::size_t f = 0; f < num; f++) { + for (std::size_t j = 0; j < sizes[i]; j++) { + auto row_id = ridx[i][j]; + int slot = slots_[f + num * row_id]; + if ((slot < 0) || (slot >= num_slot)) { + continue; + } + auto &row_ids = row_id_map[slot]; + row_ids.push_back(static_cast<int>(row_id)); + } + } + + if (print_timing_) { + std::size_t add_ops = 0; + for (auto &item: row_id_map) { + add_ops += item.second.size(); + } + std::cout << "Aggregating with " << add_ops << " additions" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto encrypted_sum = AddGHPairs(row_id_map); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast<double>(std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()) / 1000.0; + std::cout << "Aggregation time: " << secs << " seconds" << std::endl; + } + + // Convert map back to array + for (int slot = 0; slot < num_slot; slot++) { + auto it = encrypted_sum.find(slot); + if (it != encrypted_sum.end()) { + encrypted_histo[offset + slot] = it->second; + } + } + + offset += num_slot; + } + + auto encoder = DamEncoder(kDataSetAggregationResult, true, dam_debug_); + encoder.AddBufferArray(encrypted_histo); + std::size_t size; + auto buffer = encoder.Finish(size); + for (auto &item: encrypted_histo) { + FreeEncryptedData(item); + } + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = size; +} + +void LocalPlugin::SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistVert called with buffer size: " << len << " nodes" << std::endl; + print_buffer(hist_buffer, len); + } + + auto remaining = len; + auto pointer = hist_buffer; + + *out = nullptr; + *out_len = 0; + if (gh_pairs_.empty()) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistVert Do nothing for passive worker" << std::endl; + } + // Do nothing for passive worker + return; + } + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + auto first = true; + auto orig_size = histo_.size(); + while (remaining > kPrefixLen) { + DamDecoder decoder(pointer, remaining, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast<int>((pointer - hist_buffer)) << std::endl; + break; + } + auto size = decoder.Size(); + if (first) { + if (histo_.empty()) { + std::cout << "No clear histogram." << std::endl; + return; + } + first = false; + } else { + auto encrypted_buf = decoder.DecodeBufferArray(); + + if (print_timing_) { + std::cout << "Decrypting " << encrypted_buf.size() << " pairs" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto decrypted_histo = DecryptVector(encrypted_buf); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast<double>(std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()) / 1000.0; + std::cout << "Decryption time: " << secs << " seconds" << std::endl; + } + + if (decrypted_histo.size() != orig_size) { + std::cout << "Histo sizes are different: " << decrypted_histo.size() + << " != " << orig_size << std::endl; + } + histo_.insert(histo_.end(), decrypted_histo.cbegin(), decrypted_histo.cend()); + } + remaining -= size; + pointer += size; + } + + if (debug_) { + std::cout << Ident() << " Decrypted result size: " << histo_.size() << std::endl; + } + + // print_buffer(reinterpret_cast<uint8_t *>(result.data()), result.size()*8); + + *out = histo_.data(); + *out_len = histo_.size(); +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc new file mode 100644 index 0000000000..b062aecfa6 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc @@ -0,0 +1,297 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <iostream> +#include <algorithm> // for copy_n, transform +#include <cstring> // for memcpy +#include <stdexcept> // for invalid_argument +#include <vector> // for vector + +#include "nvflare_plugin.h" +#include "data_set_ids.h" +#include "dam.h" // for DamEncoder + +namespace nvflare { + +void NvflarePlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, + std::size_t *n_out) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::EncryptGPairs called with pairs size: " << n_in<< std::endl; + } + + auto pairs = std::vector<float>(in_gpair, in_gpair + n_in); + gh_pairs_ = std::vector<double>(pairs.cbegin(), pairs.cend()); + + DamEncoder encoder(kDataSetGHPairs, false, dam_debug_); + encoder.AddFloatArray(gh_pairs_); + std::size_t size; + auto buffer = encoder.Finish(size); + if (!out_gpair) { + throw std::invalid_argument{"Invalid pointer to output gpair."}; + } + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_gpair = buffer_.data(); + *n_out = size; +} + +void NvflarePlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, + std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedGPairs called with buffer size: " << n_bytes << std::endl; + } + + // For NVFlare plugin, nothing needs to be done here + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; +} + +void NvflarePlugin::ResetHistContext(std::uint32_t const *cutptrs, + std::size_t cutptr_len, + std::int32_t const *bin_idx, + std::size_t n_idx) { + if (debug_) { + std::cout << Ident() << " NvFlarePlugin::ResetHistContext called with cutptrs size: " << cutptr_len << " bin_idx size: " + << n_idx<< std::endl; + } + + cut_ptrs_.resize(cutptr_len); + std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); + bin_idx_.resize(n_idx); + std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); +} + +void NvflarePlugin::BuildEncryptedHistVert(std::uint64_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, + std::size_t len, + std::uint8_t** out_hist, + std::size_t* out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::BuildEncryptedHistVert called with len: " << len << std::endl; + } + + std::int64_t data_set_id; + if (!feature_sent_) { + data_set_id = kDataSetAggregationWithFeatures; + feature_sent_ = true; + } else { + data_set_id = kDataSetAggregation; + } + + DamEncoder encoder(data_set_id, false, dam_debug_); + + // Add cuts pointers + std::vector<int64_t> cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); + encoder.AddIntArray(cuts_vec); + + auto num_features = cut_ptrs_.size() - 1; + auto num_samples = bin_idx_.size() / num_features; + if (debug_) { + std::cout << "Samples: " << num_samples << " Features: " << num_features << std::endl; + } + + std::vector<int64_t> bins; + if (data_set_id == kDataSetAggregationWithFeatures) { + if (features_.empty()) { // when is it not empty? + for (int64_t f = 0; f < num_features; f++) { + auto slot = bin_idx_[f]; + if (slot >= 0) { + // what happens if it's missing? + features_.push_back(f); + } + } + } + encoder.AddIntArray(features_); + + for (int i = 0; i < num_samples; i++) { + for (auto f : features_) { + auto index = f + i * num_features; + if (index > bin_idx_.size()) { + throw std::out_of_range{"Index is out of range: " + + std::to_string(index)}; + } + auto slot = bin_idx_[index]; + bins.push_back(slot); + } + } + encoder.AddIntArray(bins); + } + + // Add nodes to build + std::vector<int64_t> node_vec(len); + for (std::size_t i = 0; i < len; i++) { + node_vec[i] = nidx[i]; + } + encoder.AddIntArray(node_vec); + + // For each node, get the row_id/slot pair + auto row_ids = std::vector<std::vector<int64_t>>(len); + for (std::size_t i = 0; i < len; ++i) { + auto& rows = row_ids[i]; + rows.resize(sizes[i]); + for (std::size_t j = 0; j < sizes[i]; j++) { + rows[j] = static_cast<int64_t>(ridx[i][j]); + } + encoder.AddIntArray(rows); + } + + std::size_t n{0}; + auto buffer = encoder.Finish(n); + if (debug_) { + std::cout << "Finished size: " << n << std::endl; + } + + // XGBoost doesn't allow the change of allgatherV sizes. Make sure it's big + // enough to carry histograms + auto max_slot = cut_ptrs_.back(); + auto histo_size = 2 * max_slot * sizeof(double) * len + 1024*1024; // 1M is DAM overhead + auto buf_size = histo_size > n ? histo_size : n; + + // Copy to an array so the buffer can be freed, should change encoder to return vector + buffer_.resize(buf_size); + std::copy_n(buffer, n, buffer_.begin()); + free(buffer); + + *out_hist = buffer_.data(); + *out_len = buffer_.size(); +} + +void NvflarePlugin::SyncEncryptedHistVert(std::uint8_t *buffer, + std::size_t buf_size, + double **out, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedHistVert called with buffer size: " << buf_size << std::endl; + } + + auto remaining = buf_size; + char *pointer = reinterpret_cast<char *>(buffer); + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector<double> &result = histo_; + result.clear(); + auto max_slot = cut_ptrs_.back(); + auto array_size = 2 * max_slot * sizeof(double); + + // A new histogram array? + auto slots = static_cast<double *>(malloc(array_size)); + while (remaining > kPrefixLen) { + DamDecoder decoder(reinterpret_cast<uint8_t *>(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast<int>((pointer - reinterpret_cast<char *>(buffer))) << std::endl; + break; + } + auto size = decoder.Size(); + auto node_list = decoder.DecodeIntArray(); + if (debug_) { + std::cout << "Number of nodes: " << node_list.size() << " Histo size: " << 2*max_slot << std::endl; + } + for ([[maybe_unused]] auto node : node_list) { + std::memset(slots, 0, array_size); + auto feature_list = decoder.DecodeIntArray(); + // Convert per-feature histo to a flat one + for (auto f : feature_list) { + auto base = cut_ptrs_[f]; // cut pointer for the current feature + auto bins = decoder.DecodeFloatArray(); + auto n = bins.size() / 2; + for (int i = 0; i < n; i++) { + auto index = base + i; + // [Q] Build local histogram? Why does it need to be built here? + slots[2 * index] += bins[2 * i]; + slots[2 * index + 1] += bins[2 * i + 1]; + } + } + result.insert(result.end(), slots, slots + 2 * max_slot); + } + remaining -= size; + pointer += size; + } + free(slots); + + // result is a reference to a histo_ + *out_len = result.size(); + *out = result.data(); + if (debug_) { + std::cout << "Total histogram size: " << *out_len << std::endl; + } +} + +void NvflarePlugin::BuildEncryptedHistHori(double const *in_histogram, + std::size_t len, + std::uint8_t **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::BuildEncryptedHistHori called with histo size: " << len << std::endl; + } + + DamEncoder encoder(kDataSetHistograms, false, dam_debug_); + std::vector<double> copy(in_histogram, in_histogram + len); + encoder.AddFloatArray(copy); + + std::size_t size{0}; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_hist = this->buffer_.data(); + *out_len = this->buffer_.size(); +} + +void NvflarePlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, + std::size_t len, + double **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + } + + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector<double>& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast<std::uint8_t *>(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast<int>(pointer - buffer) << std::endl; + break; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + std::to_string(decoder.GetDataSetId())}; + } + + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc new file mode 100644 index 0000000000..4a29d0ed2b --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <iostream> +#include <algorithm> + +#include "pass_thru_plugin.h" +#include "data_set_ids.h" + +namespace nvflare { + +void PassThruPlugin::BuildEncryptedHistHori(const double *in_histogram, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " PassThruPlugin::BuildEncryptedHistHori called with " << len << " entries" << std::endl; + } + + DamEncoder encoder(kDataSetHistogramResult, true, dam_debug_); + auto array = std::vector<double>(in_histogram, in_histogram + len); + encoder.AddFloatArray(array); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = buffer_.size(); +} + +void PassThruPlugin::SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, + double **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " PassThruPlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + } + + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector<double>& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast<std::uint8_t *>(pointer), remaining, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast<int>(pointer - buffer) << std::endl; + break; + } + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); +} + +Buffer PassThruPlugin::EncryptVector(const std::vector<double>& cleartext) { + if (debug_ && cleartext.size() > 2) { + std::cout << "PassThruPlugin::EncryptVector called with cleartext size: " << cleartext.size() << std::endl; + } + + size_t size = cleartext.size() * sizeof(double); + auto buf = static_cast<std::uint8_t *>(malloc(size)); + std::copy_n(reinterpret_cast<std::uint8_t const*>(cleartext.data()), size, buf); + + return {buf, size, true}; +} + +std::vector<double> PassThruPlugin::DecryptVector(const std::vector<Buffer>& ciphertext) { + if (debug_) { + std::cout << "PassThruPlugin::DecryptVector with ciphertext size: " << ciphertext.size() << std::endl; + } + + std::vector<double> result; + + for (auto const &v : ciphertext) { + size_t n = v.buf_size/sizeof(double); + auto p = static_cast<double *>(v.buffer); + for (int i = 0; i < n; i++) { + result.push_back(p[i]); + } + } + + return result; +} + +std::map<int, Buffer> PassThruPlugin::AddGHPairs(const std::map<int, std::vector<int>>& sample_ids) { + if (debug_) { + std::cout << "PassThruPlugin::AddGHPairs called with " << sample_ids.size() << " slots" << std::endl; + } + + // Can't do this in real plugin. It needs to be broken into encrypted parts + auto gh_pairs = DecryptVector(std::vector<Buffer>{Buffer(encrypted_gh_.data(), encrypted_gh_.size())}); + + auto result = std::map<int, Buffer>(); + for (auto const &entry : sample_ids) { + auto rows = entry.second; + double g = 0.0; + double h = 0.0; + + for (auto row : rows) { + g += gh_pairs[2 * row]; + h += gh_pairs[2 * row + 1]; + } + // In real plugin, the sum should be still in encrypted state. No need to do this step + auto encrypted_sum = EncryptVector(std::vector<double>{g, h}); + // print_buffer(reinterpret_cast<uint8_t *>(encrypted_sum.buffer), encrypted_sum.buf_size); + result.insert({entry.first, encrypted_sum}); + } + + return result; +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc b/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc new file mode 100644 index 0000000000..4c1d43a6f8 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc @@ -0,0 +1,184 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <memory> // for shared_ptr +#include <stdexcept> // for invalid_argument +#include <string_view> // for string_view +#include <vector> // for vector +#include <algorithm> // for transform + +#include "delegated_plugin.h" + +// Opaque pointer type for the C API. +typedef void *FederatedPluginHandle; // NOLINT + +namespace nvflare { +namespace { +// The opaque type for the C handle. +using CHandleT = std::shared_ptr<BasePlugin> *; +// Actual representation used in C++ code base. +using HandleT = std::remove_pointer_t<CHandleT>; + +std::string &GlobalErrorMsg() { + static thread_local std::string msg; + return msg; +} + +// Perform handle handling for C API functions. +template <typename Fn> auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { + auto pptr = static_cast<CHandleT>(handle); + if (!pptr) { + return 1; + } + + try { + if constexpr (std::is_void_v<std::invoke_result_t<Fn, decltype(*pptr)>>) { + fn(*pptr); + return 0; + } else { + return fn(*pptr); + } + } catch (std::exception const &e) { + GlobalErrorMsg() = e.what(); + return 1; + } +} +} // namespace +} // namespace nvflare + +#if defined(_MSC_VER) || defined(_WIN32) +#define NVF_C __declspec(dllexport) +#else +#define NVF_C __attribute__((visibility("default"))) +#endif // defined(_MSC_VER) || defined(_WIN32) + +extern "C" { +NVF_C char const *FederatedPluginErrorMsg() { + return nvflare::GlobalErrorMsg().c_str(); +} + +FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { + // std::cout << "==== FedreatedPluginCreate called with argc=" << argc << std::endl; + using namespace nvflare; + try { + auto pptr = new std::shared_ptr<BasePlugin>; + std::vector<std::pair<std::string_view, std::string_view>> args; + std::transform( + argv, argv + argc, std::back_inserter(args), [](char const *carg) { + // Split a key value pair in contructor argument: `key=value` + std::string_view arg{carg}; + auto idx = arg.find('='); + if (idx == std::string_view::npos) { + // `=` not found + throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; + } + auto key = arg.substr(0, idx); + auto value = arg.substr(idx + 1); + return std::make_pair(key, value); + }); + *pptr = std::make_shared<DelegatedPlugin>(args); + // std::cout << "==== Plugin created: " << pptr << std::endl; + return pptr; + } catch (std::exception const &e) { + // std::cout << "==== Create exception " << e.what() << std::endl; + GlobalErrorMsg() = e.what(); + return nullptr; + } +} + +int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { + using namespace nvflare; + auto pptr = static_cast<CHandleT>(handle); + if (!pptr) { + return 1; + } + + delete pptr; + + return 0; +} + +int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, + float const *in_gpair, size_t n_in, + uint8_t **out_gpair, size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + return 0; + }); +} + +int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, + uint8_t const *in_gpair, + size_t n_bytes, + uint8_t const **out_gpair, + size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); + }); +} + +int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, + uint32_t const *cutptrs, + size_t cutptr_len, + int32_t const *bin_idx, + size_t n_idx) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + }); +} + +int NVF_C FederatedPluginBuildEncryptedHistVert( + FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, + int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEncryptedHistVert(FederatedPluginHandle handle, + uint8_t *in_hist, size_t len, + double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, + double const *in_hist, + size_t len, uint8_t **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEncryptedHistHori(FederatedPluginHandle handle, + uint8_t const *in_hist, + size_t len, double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); + return 0; + }); +} +} // extern "C" diff --git a/integration/xgboost/encryption_plugins/src/plugins/util.cc b/integration/xgboost/encryption_plugins/src/plugins/util.cc new file mode 100644 index 0000000000..a0cbd922d4 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/util.cc @@ -0,0 +1,99 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <iostream> +#include <set> +#include <algorithm> +#include "util.h" + + +constexpr double kScaleFactor = 1000000.0; + +std::vector<std::pair<int, int>> distribute_work(size_t num_jobs, size_t const num_workers) { + std::vector<std::pair<int, int>> result; + auto num = num_jobs / num_workers; + auto remainder = num_jobs % num_workers; + int start = 0; + for (int i = 0; i < num_workers; i++) { + auto stop = static_cast<int>((start + num - 1)); + if (i < remainder) { + // If jobs cannot be evenly distributed, first few workers take an extra one + stop += 1; + } + + if (start <= stop) { + result.emplace_back(start, stop); + } + start = stop + 1; + } + + // Verify all jobs are distributed + int sum = 0; + for (auto &item: result) { + sum += item.second - item.first + 1; + } + + if (sum != num_jobs) { + std::cout << "Distribution error" << std::endl; + } + + return result; +} + +uint32_t to_int(double d) { + auto int_val = static_cast<int32_t>(d * kScaleFactor); + return static_cast<uint32_t>(int_val); +} + +double to_double(uint32_t i) { + auto int_val = static_cast<int32_t>(i); + return static_cast<double>(int_val / kScaleFactor); +} + +std::string get_string(std::vector<std::pair<std::string_view, std::string_view>> const &args, + std::string_view const &key, std::string_view const default_value) { + + auto it = find_if( + args.begin(), args.end(), + [key](const auto &p) { return p.first == key; }); + + if (it != args.end()) { + return std::string{it->second}; + } + + return std::string{default_value}; +} + +bool get_bool(std::vector<std::pair<std::string_view, std::string_view>> const &args, + const std::string &key, bool default_value) { + std::string value = get_string(args, key, ""); + if (value.empty()) { + return default_value; + } + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { return std::tolower(c); }); + auto true_values = std::set < std::string_view > {"true", "yes", "y", "on", "1"}; + return true_values.count(value) > 0; +} + +int get_int(std::vector<std::pair<std::string_view, std::string_view>> const &args, + const std::string &key, int default_value) { + + auto value = get_string(args, key, ""); + if (value.empty()) { + return default_value; + } + + return stoi(value, nullptr); +} diff --git a/integration/xgboost/encryption_plugins/tests/CMakeLists.txt b/integration/xgboost/encryption_plugins/tests/CMakeLists.txt new file mode 100644 index 0000000000..04580bdd59 --- /dev/null +++ b/integration/xgboost/encryption_plugins/tests/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE TEST_SOURCES "*.cc") + +target_sources(xgb_nvflare_test PRIVATE ${TEST_SOURCES}) + +target_include_directories(xgb_nvflare_test + PRIVATE + ${GTEST_INCLUDE_DIRS} + ${xgb_nvflare_SOURCE_DIR/tests} + ${xgb_nvflare_SOURCE_DIR}/src) + +message("Include Dir: ${GTEST_INCLUDE_DIRS}") +target_link_libraries(xgb_nvflare_test + PRIVATE + ${GTEST_LIBRARIES}) diff --git a/integration/xgboost/processor/tests/test_dam.cc b/integration/xgboost/encryption_plugins/tests/test_dam.cc similarity index 65% rename from integration/xgboost/processor/tests/test_dam.cc rename to integration/xgboost/encryption_plugins/tests/test_dam.cc index 5573d5440d..345978b110 100644 --- a/integration/xgboost/processor/tests/test_dam.cc +++ b/integration/xgboost/encryption_plugins/tests/test_dam.cc @@ -19,20 +19,45 @@ TEST(DamTest, TestEncodeDecode) { double float_array[] = {1.1, 1.2, 1.3, 1.4}; int64_t int_array[] = {123, 456, 789}; + char buf1[] = "short"; + char buf2[] = "very long"; + DamEncoder encoder(123); + auto b1 = Buffer(buf1, strlen(buf1)); + auto b2 = Buffer(buf2, strlen(buf2)); + encoder.AddBuffer(b1); + encoder.AddBuffer(b2); + + std::vector<Buffer> b{b1, b2}; + encoder.AddBufferArray(b); + auto f = std::vector<double>(float_array, float_array + 4); encoder.AddFloatArray(f); + auto i = std::vector<int64_t>(int_array, int_array + 3); encoder.AddIntArray(i); + size_t size; auto buf = encoder.Finish(size); std::cout << "Encoded size is " << size << std::endl; - DamDecoder decoder(buf.data(), size); + // Decoding test + DamDecoder decoder(buf, size); EXPECT_EQ(decoder.IsValid(), true); EXPECT_EQ(decoder.GetDataSetId(), 123); + auto new_buf1 = decoder.DecodeBuffer(); + EXPECT_EQ(0, memcmp(new_buf1.buffer, buf1, new_buf1.buf_size)); + + auto new_buf2 = decoder.DecodeBuffer(); + EXPECT_EQ(0, memcmp(new_buf2.buffer, buf2, new_buf2.buf_size)); + + auto buf_vec = decoder.DecodeBufferArray(); + EXPECT_EQ(2, buf_vec.size()); + EXPECT_EQ(0, memcmp(buf_vec[0].buffer, buf1, buf_vec[0].buf_size)); + EXPECT_EQ(0, memcmp(buf_vec[1].buffer, buf2, buf_vec[1].buf_size)); + auto float_vec = decoder.DecodeFloatArray(); EXPECT_EQ(0, memcmp(float_vec.data(), float_array, float_vec.size()*8)); diff --git a/integration/xgboost/processor/tests/test_main.cc b/integration/xgboost/encryption_plugins/tests/test_main.cc similarity index 100% rename from integration/xgboost/processor/tests/test_main.cc rename to integration/xgboost/encryption_plugins/tests/test_main.cc diff --git a/integration/xgboost/processor/tests/test_tenseal.py b/integration/xgboost/encryption_plugins/tests/test_tenseal.py similarity index 100% rename from integration/xgboost/processor/tests/test_tenseal.py rename to integration/xgboost/encryption_plugins/tests/test_tenseal.py diff --git a/integration/xgboost/processor/CMakeLists.txt b/integration/xgboost/processor/CMakeLists.txt deleted file mode 100644 index 056fd365e2..0000000000 --- a/integration/xgboost/processor/CMakeLists.txt +++ /dev/null @@ -1,46 +0,0 @@ -cmake_minimum_required(VERSION 3.19) -project(proc_nvflare LANGUAGES CXX C VERSION 1.0) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) - -option(GOOGLE_TEST "Build google tests" OFF) - -file(GLOB_RECURSE LIB_SRC "src/*.cc") - -add_library(proc_nvflare SHARED ${LIB_SRC}) -set_target_properties(proc_nvflare PROPERTIES - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON - ENABLE_EXPORTS ON -) -target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include) - -if (APPLE) - add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o") - add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache") -endif () - -#-- Unit Tests -if(GOOGLE_TEST) - find_package(GTest REQUIRED) - enable_testing() - add_executable(proc_test) - target_link_libraries(proc_test PRIVATE proc_nvflare) - - - target_include_directories(proc_test PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include - ${XGB_SRC}/src - ${XGB_SRC}/rabit/include - ${XGB_SRC}/include - ${XGB_SRC}/dmlc-core/include - ${XGB_SRC}/tests) - - add_subdirectory(${proc_nvflare_SOURCE_DIR}/tests) - - add_test( - NAME TestProcessor - COMMAND proc_test - WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR}) - -endif() diff --git a/integration/xgboost/processor/README.md b/integration/xgboost/processor/README.md deleted file mode 100644 index e879081b84..0000000000 --- a/integration/xgboost/processor/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Build Instruction - -``` sh -cd NVFlare/integration/xgboost/processor -mkdir build -cd build -cmake .. -make -``` - -See [tests](./tests) for simple examples. \ No newline at end of file diff --git a/integration/xgboost/processor/src/dam/dam.cc b/integration/xgboost/processor/src/dam/dam.cc deleted file mode 100644 index 10625ab9b5..0000000000 --- a/integration/xgboost/processor/src/dam/dam.cc +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include <iostream> -#include <cstring> -#include "dam.h" - -void print_buffer(uint8_t *buffer, int size) { - for (int i = 0; i < size; i++) { - auto c = buffer[i]; - std::cout << std::hex << (int) c << " "; - } - std::cout << std::endl << std::dec; -} - -// DamEncoder ====== -void DamEncoder::AddFloatArray(const std::vector<double> &value) { - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size() * 8; - uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); -} - -void DamEncoder::AddIntArray(const std::vector<int64_t> &value) { - std::cout << "AddIntArray called, size: " << value.size() << std::endl; - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size()*8; - std::cout << "Allocating " << buf_size << " bytes" << std::endl; - uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - // print_buffer(buffer, buf_size); - entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size())); -} - -std::vector<std::uint8_t> DamEncoder::Finish(size_t &size) { - encoded = true; - - size = calculate_size(); - std::vector<std::uint8_t> buf(size); - auto pointer = buf.data(); - memcpy(pointer, kSignature, strlen(kSignature)); - memcpy(pointer + 8, &size, 8); - memcpy(pointer + 16, &data_set_id, 8); - - pointer += kPrefixLen; - for (auto entry : *entries) { - memcpy(pointer, &entry->data_type, 8); - pointer += 8; - memcpy(pointer, &entry->size, 8); - pointer += 8; - int len = 8*entry->size; - memcpy(pointer, entry->pointer, len); - free(entry->pointer); - pointer += len; - // print_buffer(entry->pointer, entry->size*8); - } - - if ((pointer - buf.data()) != size) { - throw std::runtime_error{"Invalid encoded size: " + - std::to_string(pointer - buf.data())}; - } - - return buf; -} - -std::size_t DamEncoder::calculate_size() { - auto size = kPrefixLen; - - for (auto entry : *entries) { - size += 16; // The Type and Len - size += entry->size * 8; // All supported data types are 8 bytes - } - - return size; -} - - -// DamDecoder ====== - -DamDecoder::DamDecoder(std::uint8_t const *buffer, std::size_t size) { - this->buffer = buffer; - this->buf_size = size; - this->pos = buffer + kPrefixLen; - if (size >= kPrefixLen) { - memcpy(&len, buffer + 8, 8); - memcpy(&data_set_id, buffer + 16, 8); - } else { - len = 0; - data_set_id = 0; - } -} - -bool DamDecoder::IsValid() { - return buf_size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0; -} - -std::vector<int64_t> DamDecoder::DecodeIntArray() { - auto type = *reinterpret_cast<int64_t const*>(pos); - if (type != kDataTypeIntArray) { - std::cout << "Data type " << type << " doesn't match Int Array" - << std::endl; - return std::vector<int64_t>(); - } - pos += 8; - - auto len = *reinterpret_cast<int64_t const *>(pos); - pos += 8; - auto ptr = reinterpret_cast<int64_t const *>(pos); - pos += 8 * len; - return std::vector<int64_t>(ptr, ptr + len); -} - -std::vector<double> DamDecoder::DecodeFloatArray() { - auto type = *reinterpret_cast<int64_t const*>(pos); - if (type != kDataTypeFloatArray) { - std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; - return std::vector<double>(); - } - pos += 8; - - auto len = *reinterpret_cast<int64_t const *>(pos); - pos += 8; - - auto ptr = reinterpret_cast<double const *>(pos); - pos += 8*len; - return std::vector<double>(ptr, ptr + len); -} diff --git a/integration/xgboost/processor/src/include/dam.h b/integration/xgboost/processor/src/include/dam.h deleted file mode 100644 index 7afdf983af..0000000000 --- a/integration/xgboost/processor/src/include/dam.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include <vector> -#include <cstdint> // for int64_t -#include <cstddef> // for size_t - -const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 -const int kPrefixLen = 24; - -const int kDataTypeInt = 1; -const int kDataTypeFloat = 2; -const int kDataTypeString = 3; -const int kDataTypeIntArray = 257; -const int kDataTypeFloatArray = 258; - -const int kDataTypeMap = 1025; - -class Entry { - public: - int64_t data_type; - uint8_t * pointer; - int64_t size; - - Entry(int64_t data_type, uint8_t *pointer, int64_t size) { - this->data_type = data_type; - this->pointer = pointer; - this->size = size; - } -}; - -class DamEncoder { - private: - bool encoded = false; - int64_t data_set_id; - std::vector<Entry *> *entries = new std::vector<Entry *>(); - - public: - explicit DamEncoder(int64_t data_set_id) { - this->data_set_id = data_set_id; - } - - void AddIntArray(const std::vector<int64_t> &value); - - void AddFloatArray(const std::vector<double> &value); - - std::vector<std::uint8_t> Finish(size_t &size); - - private: - std::size_t calculate_size(); -}; - -class DamDecoder { - private: - std::uint8_t const *buffer = nullptr; - std::size_t buf_size = 0; - std::uint8_t const *pos = nullptr; - std::size_t remaining = 0; - int64_t data_set_id = 0; - int64_t len = 0; - - public: - explicit DamDecoder(std::uint8_t const *buffer, std::size_t size); - - size_t Size() { - return len; - } - - int64_t GetDataSetId() { - return data_set_id; - } - - bool IsValid(); - - std::vector<int64_t> DecodeIntArray(); - - std::vector<double> DecodeFloatArray(); -}; - -void print_buffer(uint8_t *buffer, int size); diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc deleted file mode 100644 index 3e742b14ef..0000000000 --- a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc +++ /dev/null @@ -1,378 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "nvflare_processor.h" - -#include "dam.h" // for DamEncoder -#include <iostream> -#include <algorithm> // for copy_n, transform -#include <cstring> // for memcpy -#include <memory> // for shared_ptr -#include <stdexcept> // for invalid_argument -#include <string_view> // for string_view -#include <vector> // for vector - -namespace nvflare { -namespace { -// The opaque type for the C handle. -using CHandleT = std::shared_ptr<TensealPlugin> *; -// Actual representation used in C++ code base. -using HandleT = std::remove_pointer_t<CHandleT>; - -std::string &GlobalErrorMsg() { - static thread_local std::string msg; - return msg; -} - -// Perform handle handling for C API functions. -template <typename Fn> auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { - auto pptr = static_cast<CHandleT>(handle); - if (!pptr) { - return 1; - } - - try { - if constexpr (std::is_void_v<std::invoke_result_t<Fn, decltype(*pptr)>>) { - fn(*pptr); - return 0; - } else { - return fn(*pptr); - } - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return 1; - } -} -} // namespace - -TensealPlugin::TensealPlugin( - std::vector<std::pair<std::string_view, std::string_view>> const &args) { - if (!args.empty()) { - throw std::invalid_argument{"Invaid arguments for the tenseal plugin."}; - } -} - -void TensealPlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, - std::uint8_t **out_gpair, - std::size_t *n_out) { - std::vector<double> pairs(n_in); - std::copy_n(in_gpair, n_in, pairs.begin()); - DamEncoder encoder(kDataSetHGPairs); - encoder.AddFloatArray(pairs); - encrypted_gpairs_ = encoder.Finish(*n_out); - if (!out_gpair) { - throw std::invalid_argument{"Invalid pointer to output gpair."}; - } - *out_gpair = encrypted_gpairs_.data(); - *n_out = encrypted_gpairs_.size(); -} - -void TensealPlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, - std::size_t n_bytes, - std::uint8_t const **out_gpair, - std::size_t *out_n_bytes) { - *out_n_bytes = n_bytes; - *out_gpair = in_gpair; -} - -void TensealPlugin::ResetHistContext(std::uint32_t const *cutptrs, - std::size_t cutptr_len, - std::int32_t const *bin_idx, - std::size_t n_idx) { - // fixme: this doesn't have to be copied multiple times. - this->cut_ptrs_.resize(cutptr_len); - std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); - this->bin_idx_.resize(n_idx); - std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); -} - -void TensealPlugin::BuildEncryptedHistVert(std::size_t const **ridx, - std::size_t const *sizes, - std::int32_t const *nidx, - std::size_t len, - std::uint8_t** out_hist, - std::size_t* out_len) { - std::int64_t data_set_id; - if (!feature_sent_) { - data_set_id = kDataSetAggregationWithFeatures; - feature_sent_ = true; - } else { - data_set_id = kDataSetAggregation; - } - - DamEncoder encoder(data_set_id); - - // Add cuts pointers - std::vector<int64_t> cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); - encoder.AddIntArray(cuts_vec); - - auto num_features = cut_ptrs_.size() - 1; - auto num_samples = bin_idx_.size() / num_features; - - if (data_set_id == kDataSetAggregationWithFeatures) { - if (features_.empty()) { // when is it not empty? - for (std::size_t f = 0; f < num_features; f++) { - auto slot = bin_idx_[f]; - if (slot >= 0) { - // what happens if it's missing? - features_.push_back(f); - } - } - } - encoder.AddIntArray(features_); - - std::vector<int64_t> bins; - for (int i = 0; i < num_samples; i++) { - for (auto f : features_) { - auto index = f + i * num_features; - if (index > bin_idx_.size()) { - throw std::out_of_range{"Index is out of range: " + - std::to_string(index)}; - } - auto slot = bin_idx_[index]; - bins.push_back(slot); - } - } - encoder.AddIntArray(bins); - } - - // Add nodes to build - std::vector<int64_t> node_vec(len); - std::copy_n(nidx, len, node_vec.begin()); - encoder.AddIntArray(node_vec); - - // For each node, get the row_id/slot pair - for (std::size_t i = 0; i < len; ++i) { - std::vector<int64_t> rows(sizes[i]); - std::copy_n(ridx[i], sizes[i], rows.begin()); - encoder.AddIntArray(rows); - } - - std::size_t n{0}; - encrypted_hist_ = encoder.Finish(n); - - *out_hist = encrypted_hist_.data(); - *out_len = encrypted_hist_.size(); -} - -void TensealPlugin::SyncEncryptedHistVert(std::uint8_t *buffer, - std::size_t buf_size, double **out, - std::size_t *out_len) { - auto remaining = buf_size; - char *pointer = reinterpret_cast<char *>(buffer); - - // The buffer is concatenated by AllGather. It may contain multiple DAM - // buffers - std::vector<double> &result = hist_; - result.clear(); - auto max_slot = cut_ptrs_.back(); - auto array_size = 2 * max_slot * sizeof(double); - // A new histogram array? - double *slots = static_cast<double *>(malloc(array_size)); - while (remaining > kPrefixLen) { - DamDecoder decoder(reinterpret_cast<uint8_t *>(pointer), remaining); - if (!decoder.IsValid()) { - std::cout << "Not DAM encoded buffer ignored at offset: " - << static_cast<int>( - (pointer - reinterpret_cast<char *>(buffer))) - << std::endl; - break; - } - auto size = decoder.Size(); - auto node_list = decoder.DecodeIntArray(); - for (auto node : node_list) { - std::memset(slots, 0, array_size); - auto feature_list = decoder.DecodeIntArray(); - // Convert per-feature histo to a flat one - for (auto f : feature_list) { - auto base = cut_ptrs_[f]; // cut pointer for the current feature - auto bins = decoder.DecodeFloatArray(); - auto n = bins.size() / 2; - for (int i = 0; i < n; i++) { - auto index = base + i; - // [Q] Build local histogram? Why does it need to be built here? - slots[2 * index] += bins[2 * i]; - slots[2 * index + 1] += bins[2 * i + 1]; - } - } - result.insert(result.end(), slots, slots + 2 * max_slot); - } - remaining -= size; - pointer += size; - } - free(slots); - - *out_len = result.size(); - *out = result.data(); -} - -void TensealPlugin::BuildEncryptedHistHori(double const *in_histogram, - std::size_t len, - std::uint8_t **out_hist, - std::size_t *out_len) { - DamEncoder encoder(kDataSetHistograms); - std::vector<double> copy(in_histogram, in_histogram + len); - encoder.AddFloatArray(copy); - - std::size_t size{0}; - this->encrypted_hist_ = encoder.Finish(size); - - *out_hist = this->encrypted_hist_.data(); - *out_len = this->encrypted_hist_.size(); -} - -void TensealPlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, - std::size_t len, double **out_hist, - std::size_t *out_len) { - DamDecoder decoder(reinterpret_cast<uint8_t const *>(buffer), len); - if (!decoder.IsValid()) { - std::cout << "Not DAM encoded buffer, ignored" << std::endl; - } - - if (decoder.GetDataSetId() != kDataSetHistogramResult) { - throw std::runtime_error{"Invalid dataset: " + - std::to_string(decoder.GetDataSetId())}; - } - this->hist_ = decoder.DecodeFloatArray(); - *out_hist = this->hist_.data(); - *out_len = this->hist_.size(); -} -} // namespace nvflare - -#if defined(_MSC_VER) || defined(_WIN32) -#define NVF_C __declspec(dllexport) -#else -#define NVF_C __attribute__((visibility("default"))) -#endif // defined(_MSC_VER) || defined(_WIN32) - -extern "C" { -NVF_C char const *FederatedPluginErrorMsg() { - return nvflare::GlobalErrorMsg().c_str(); -} - -FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { - using namespace nvflare; - try { - CHandleT pptr = new std::shared_ptr<TensealPlugin>; - std::vector<std::pair<std::string_view, std::string_view>> args; - std::transform( - argv, argv + argc, std::back_inserter(args), [](char const *carg) { - // Split a key value pair in contructor argument: `key=value` - std::string_view arg{carg}; - auto idx = arg.find('='); - if (idx == std::string_view::npos) { - // `=` not found - throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; - } - auto key = arg.substr(0, idx); - auto value = arg.substr(idx + 1); - return std::make_pair(key, value); - }); - *pptr = std::make_shared<TensealPlugin>(args); - return pptr; - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return nullptr; - } -} - -int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { - using namespace nvflare; - auto pptr = static_cast<CHandleT>(handle); - if (!pptr) { - return 1; - } - - try { - delete pptr; - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return 1; - } - return 0; -} - -int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, - float const *in_gpair, size_t n_in, - uint8_t **out_gpair, size_t *n_out) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); - return 0; - }); -} - -int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, - uint8_t const *in_gpair, - size_t n_bytes, - uint8_t const **out_gpair, - size_t *n_out) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); - }); -} - -int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, - uint32_t const *cutptrs, - size_t cutptr_len, - int32_t const *bin_idx, - size_t n_idx) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); - }); -} - -int NVF_C FederatedPluginBuildEncryptedHistVert( - FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, - int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginSyncEnrcyptedHistVert(FederatedPluginHandle handle, - uint8_t *in_hist, size_t len, - double **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, - double const *in_hist, - size_t len, uint8_t **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginSyncEnrcyptedHistHori(FederatedPluginHandle handle, - uint8_t const *in_hist, - size_t len, double **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); - return 0; - }); -} -} // extern "C" diff --git a/integration/xgboost/processor/tests/CMakeLists.txt b/integration/xgboost/processor/tests/CMakeLists.txt deleted file mode 100644 index 893d8738dc..0000000000 --- a/integration/xgboost/processor/tests/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -file(GLOB_RECURSE TEST_SOURCES "*.cc") - -target_sources(proc_test PRIVATE ${TEST_SOURCES}) - -target_include_directories(proc_test - PRIVATE - ${GTEST_INCLUDE_DIRS} - ${proc_nvflare_SOURCE_DIR/tests} - ${proc_nvflare_SOURCE_DIR}/src) - -message("Include Dir: ${GTEST_INCLUDE_DIRS}") -target_link_libraries(proc_test - PRIVATE - ${GTEST_LIBRARIES}) diff --git a/job_templates/cyclic_cc_pt/config_fed_client.conf b/job_templates/cyclic_cc_pt/config_fed_client.conf index db5f3ce318..f764c8837e 100644 --- a/job_templates/cyclic_cc_pt/config_fed_client.conf +++ b/job_templates/cyclic_cc_pt/config_fed_client.conf @@ -85,7 +85,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/cyclic_pt/config_fed_client.conf b/job_templates/cyclic_pt/config_fed_client.conf index 2fd760d9d9..df8581a6d4 100644 --- a/job_templates/cyclic_pt/config_fed_client.conf +++ b/job_templates/cyclic_pt/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_cse_pt/config_fed_client.conf b/job_templates/sag_cse_pt/config_fed_client.conf index 83455cf91f..bc0a93f8f1 100644 --- a/job_templates/sag_cse_pt/config_fed_client.conf +++ b/job_templates/sag_cse_pt/config_fed_client.conf @@ -62,7 +62,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_gnn/app_1/config_fed_client.conf b/job_templates/sag_gnn/app_1/config_fed_client.conf index 81c6fdaf6d..b4a4f89680 100644 --- a/job_templates/sag_gnn/app_1/config_fed_client.conf +++ b/job_templates/sag_gnn/app_1/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_gnn/app_2/config_fed_client.conf b/job_templates/sag_gnn/app_2/config_fed_client.conf index 81c6fdaf6d..b4a4f89680 100644 --- a/job_templates/sag_gnn/app_2/config_fed_client.conf +++ b/job_templates/sag_gnn/app_2/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_np/config_fed_client.conf b/job_templates/sag_np/config_fed_client.conf index bfb440f7c6..6d0fbad6b5 100644 --- a/job_templates/sag_np/config_fed_client.conf +++ b/job_templates/sag_np/config_fed_client.conf @@ -67,7 +67,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_np_cell_pipe/config_fed_client.conf b/job_templates/sag_np_cell_pipe/config_fed_client.conf index ec02ec7042..7b5890a5f2 100644 --- a/job_templates/sag_np_cell_pipe/config_fed_client.conf +++ b/job_templates/sag_np_cell_pipe/config_fed_client.conf @@ -67,7 +67,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_np_metrics/config_fed_client.conf b/job_templates/sag_np_metrics/config_fed_client.conf index 7164f195ae..22d7b2fc0e 100644 --- a/job_templates/sag_np_metrics/config_fed_client.conf +++ b/job_templates/sag_np_metrics/config_fed_client.conf @@ -67,7 +67,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_pt/config_fed_client.conf b/job_templates/sag_pt/config_fed_client.conf index aaba629957..9d89605160 100644 --- a/job_templates/sag_pt/config_fed_client.conf +++ b/job_templates/sag_pt/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_pt_deploy_map/app_1/config_fed_client.conf b/job_templates/sag_pt_deploy_map/app_1/config_fed_client.conf index bb84af105c..6667f3e292 100644 --- a/job_templates/sag_pt_deploy_map/app_1/config_fed_client.conf +++ b/job_templates/sag_pt_deploy_map/app_1/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_pt_deploy_map/app_2/config_fed_client.conf b/job_templates/sag_pt_deploy_map/app_2/config_fed_client.conf index bb84af105c..6667f3e292 100644 --- a/job_templates/sag_pt_deploy_map/app_2/config_fed_client.conf +++ b/job_templates/sag_pt_deploy_map/app_2/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_pt_he/config_fed_client.conf b/job_templates/sag_pt_he/config_fed_client.conf index d6b8eeccea..21e73aa94f 100644 --- a/job_templates/sag_pt_he/config_fed_client.conf +++ b/job_templates/sag_pt_he/config_fed_client.conf @@ -84,7 +84,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_pt_mlflow/config_fed_client.conf b/job_templates/sag_pt_mlflow/config_fed_client.conf index aaba629957..9d89605160 100644 --- a/job_templates/sag_pt_mlflow/config_fed_client.conf +++ b/job_templates/sag_pt_mlflow/config_fed_client.conf @@ -64,7 +64,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sag_tf/config_fed_client.conf b/job_templates/sag_tf/config_fed_client.conf index c755f8fb45..39e77a9d82 100644 --- a/job_templates/sag_tf/config_fed_client.conf +++ b/job_templates/sag_tf/config_fed_client.conf @@ -55,7 +55,7 @@ components = [ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sklearn_kmeans/config_fed_client.conf b/job_templates/sklearn_kmeans/config_fed_client.conf index 791cb1bb6d..d712cb5d9d 100755 --- a/job_templates/sklearn_kmeans/config_fed_client.conf +++ b/job_templates/sklearn_kmeans/config_fed_client.conf @@ -65,7 +65,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sklearn_linear/config_fed_client.conf b/job_templates/sklearn_linear/config_fed_client.conf index c77a9f41f2..763a0e3527 100755 --- a/job_templates/sklearn_linear/config_fed_client.conf +++ b/job_templates/sklearn_linear/config_fed_client.conf @@ -65,7 +65,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/sklearn_svm/config_fed_client.conf b/job_templates/sklearn_svm/config_fed_client.conf index 5e42535bfd..2ab6705afc 100755 --- a/job_templates/sklearn_svm/config_fed_client.conf +++ b/job_templates/sklearn_svm/config_fed_client.conf @@ -65,7 +65,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/swarm_cse_pt/config_fed_client.conf b/job_templates/swarm_cse_pt/config_fed_client.conf index af248157bb..1380e17ae1 100644 --- a/job_templates/swarm_cse_pt/config_fed_client.conf +++ b/job_templates/swarm_cse_pt/config_fed_client.conf @@ -89,7 +89,7 @@ components = [ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/job_templates/xgboost_tree/config_fed_client.conf b/job_templates/xgboost_tree/config_fed_client.conf index 936cf5afcf..c9db9d7311 100755 --- a/job_templates/xgboost_tree/config_fed_client.conf +++ b/job_templates/xgboost_tree/config_fed_client.conf @@ -65,7 +65,7 @@ args { # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " + script = "python3 -u custom/{app_script} {app_config} " # if launch_once is true, the SubprocessLauncher will launch once for the whole job # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server launch_once = true diff --git a/nvflare/app_common/executors/task_exchanger.py b/nvflare/app_common/executors/task_exchanger.py index 3d4eddf200..707516e72b 100644 --- a/nvflare/app_common/executors/task_exchanger.py +++ b/nvflare/app_common/executors/task_exchanger.py @@ -111,6 +111,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): ) self.pipe_handler.set_status_cb(self._pipe_status_cb) self.pipe.open(self.pipe_channel_name) + elif event_type == EventType.BEFORE_TASK_EXECUTION: self.pipe_handler.start() elif event_type == EventType.ABOUT_TO_END_RUN: self.log_info(fl_ctx, "Stopping pipe handler") diff --git a/nvflare/app_common/launchers/subprocess_launcher.py b/nvflare/app_common/launchers/subprocess_launcher.py index 6884b6d6a0..71e0b459e5 100644 --- a/nvflare/app_common/launchers/subprocess_launcher.py +++ b/nvflare/app_common/launchers/subprocess_launcher.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import shlex import subprocess +from threading import Thread from typing import Optional from nvflare.apis.fl_context import FLContext @@ -23,6 +25,11 @@ from nvflare.app_common.abstract.launcher import Launcher, LauncherRunStatus +def log_subprocess_output(process, logger): + for c in iter(process.stdout.readline, b""): + logger.info(c.decode().rstrip()) + + class SubprocessLauncher(Launcher): def __init__(self, script: str, launch_once: bool = True, clean_up_script: Optional[str] = None): """Initializes the SubprocessLauncher. @@ -38,6 +45,7 @@ def __init__(self, script: str, launch_once: bool = True, clean_up_script: Optio self._script = script self._launch_once = launch_once self._clean_up_script = clean_up_script + self.logger = logging.getLogger(self.__class__.__name__) def initialize(self, fl_ctx: FLContext): self._app_dir = self.get_app_dir(fl_ctx) @@ -64,18 +72,17 @@ def _start_external_process(self): env["CLIENT_API_TYPE"] = "EX_PROCESS_API" command_seq = shlex.split(command) - self._process = subprocess.Popen( - command_seq, - stderr=subprocess.STDOUT, - cwd=self._app_dir, - env=env, + command_seq, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=self._app_dir, env=env ) + self._log_thread = Thread(target=log_subprocess_output, args=(self._process, self.logger)) + self._log_thread.start() def _stop_external_process(self): if self._process: self._process.terminate() self._process.wait() + self._log_thread.join() if self._clean_up_script: command_seq = shlex.split(self._clean_up_script) process = subprocess.Popen(command_seq, cwd=self._app_dir) @@ -83,11 +90,11 @@ def _stop_external_process(self): self._process = None def check_run_status(self, task_name: str, fl_ctx: FLContext) -> str: - if self._process: - return_code = self._process.poll() - if return_code is None: - return LauncherRunStatus.RUNNING - elif return_code == 0: - return LauncherRunStatus.COMPLETE_SUCCESS - return LauncherRunStatus.COMPLETE_FAILED - return LauncherRunStatus.NOT_RUNNING + if self._process is None: + return LauncherRunStatus.NOT_RUNNING + return_code = self._process.poll() + if return_code is None: + return LauncherRunStatus.RUNNING + if return_code == 0: + return LauncherRunStatus.COMPLETE_SUCCESS + return LauncherRunStatus.COMPLETE_FAILED diff --git a/nvflare/app_common/storages/filesystem_storage.py b/nvflare/app_common/storages/filesystem_storage.py index b1d5547a0f..bfa82654f0 100644 --- a/nvflare/app_common/storages/filesystem_storage.py +++ b/nvflare/app_common/storages/filesystem_storage.py @@ -29,6 +29,20 @@ def _write(path: str, content, mv_file=True): + """Create a file at the specified 'path' with the specified 'content'. + + Args: + path: the path of the file to be created + content: content for the file to be created. It could be either bytes, or path (str) to the source file that + contains the content. + mv_file: whether the destination file should be created simply by moving the source file. This is applicable + only when the 'content' is the path of the source file. If mv_file is False, the destination is created + by copying from the source file, and the source file will remain intact; If mv_file is True, the + destination file is created by "move" the source file, and the original source file will no longer exist. + + Returns: + + """ tmp_path = path + "_" + str(uuid.uuid4()) try: Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) @@ -162,7 +176,7 @@ def clone_object(self, from_uri: str, to_uri: str, meta: dict, overwrite_existin from_full_uri = self._object_path(from_uri) from_data_path = os.path.join(from_full_uri, DATA) - _write(data_path, from_data_path) + _write(data_path, from_data_path, mv_file=False) meta_path = os.path.join(full_uri, META) try: diff --git a/nvflare/app_common/widgets/metric_relay.py b/nvflare/app_common/widgets/metric_relay.py index cf321a59d4..deac9c1b34 100644 --- a/nvflare/app_common/widgets/metric_relay.py +++ b/nvflare/app_common/widgets/metric_relay.py @@ -67,6 +67,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): self.pipe_handler.set_status_cb(self._pipe_status_cb) self.pipe_handler.set_message_cb(self._pipe_msg_cb) self.pipe.open(self.pipe_channel_name) + elif event_type == EventType.BEFORE_TASK_EXECUTION: self.pipe_handler.start() elif event_type == EventType.ABOUT_TO_END_RUN: self.log_info(fl_ctx, "Stopping pipe handler") diff --git a/nvflare/app_opt/tf/fedopt_ctl.py b/nvflare/app_opt/tf/fedopt_ctl.py new file mode 100644 index 0000000000..0ec0d2420b --- /dev/null +++ b/nvflare/app_opt/tf/fedopt_ctl.py @@ -0,0 +1,160 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from typing import Dict + +import tensorflow as tf + +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.workflows.fedavg import FedAvg +from nvflare.security.logging import secure_format_exception + + +class FedOpt(FedAvg): + def __init__( + self, + *args, + optimizer_args: dict = { + "path": "tensorflow.keras.optimizers.SGD", + "args": {"learning_rate": 1.0, "momentum": 0.6}, + }, + lr_scheduler_args: dict = { + "path": "tensorflow.keras.optimizers.schedules.CosineDecay", + "args": {"initial_learning_rate": 1.0, "decay_steps": None, "alpha": 0.9}, + }, + **kwargs, + ): + """Implement the FedOpt algorithm. Based on FedAvg ModelController. + + The algorithm is proposed in Reddi, Sashank, et al. "Adaptive federated optimization." arXiv preprint arXiv:2003.00295 (2020). + After each round, update the global model's trainable variables using the specified optimizer and learning rate scheduler, + in this case, SGD with momentum & CosineDecay. + + Args: + optimizer_args: dictionary of optimizer arguments, with keys of 'optimizer_path' and 'args. + lr_scheduler_args: dictionary of server-side learning rate scheduler arguments, with keys of 'lr_scheduler_path' and 'args. + + Raises: + TypeError: when any of input arguments does not have correct type + """ + super().__init__(*args, **kwargs) + + self.optimizer_args = optimizer_args + self.lr_scheduler_args = lr_scheduler_args + + # Set "decay_steps" arg to num_rounds + if lr_scheduler_args["args"]["decay_steps"] is None: + lr_scheduler_args["args"]["decay_steps"] = self.num_rounds + + self.keras_model = None + self.optimizer = None + self.lr_scheduler = None + + def run(self): + """ + Override run method to add set-up for FedOpt specific optimizer + and LR scheduler. + """ + # set up optimizer + try: + if "args" not in self.optimizer_args: + self.optimizer_args["args"] = {} + self.optimizer = self.build_component(self.optimizer_args) + except Exception as e: + error_msg = f"Exception while constructing optimizer: {secure_format_exception(e)}" + self.exception(error_msg) + self.panic(error_msg) + return + + # set up lr scheduler + try: + if "args" not in self.lr_scheduler_args: + self.lr_scheduler_args["args"] = {} + self.lr_scheduler = self.build_component(self.lr_scheduler_args) + self.optimizer.learning_rate = self.lr_scheduler + except Exception as e: + error_msg = f"Exception while constructing lr_scheduler: {secure_format_exception(e)}" + self.exception(error_msg) + self.panic(error_msg) + return + + super().run() + + def _to_tf_params_list(self, params: Dict, negate: bool = False): + """ + Convert FLModel params to a list of tf.Variables. + Optionally negate the values of weights, needed + to apply gradients. + """ + tf_params_list = [] + for k, v in params.items(): + if negate: + v = -1 * v + tf_params_list.append(tf.Variable(v)) + return tf_params_list + + def update_model(self, global_model: FLModel, aggr_result: FLModel): + """ + Override the default version of update_model + to perform update with Keras Optimizer on the + global model stored in memory in persistor, instead of + creating new temporary model on-the-fly. + + Creating a new model would not work for Keras + Optimizers, since an optimizer is bind to + specific set of Variables. + + """ + # Get the Keras model stored in memory in persistor. + global_model_tf = self.persistor.model + global_params = global_model_tf.trainable_weights + + # Compute model diff: need to use model diffs as + # gradients to be applied by the optimizer. + model_diff_params = {k: aggr_result.params[k] - global_model.params[k] for k in global_model.params} + model_diff = self._to_tf_params_list(model_diff_params, negate=True) + + # Apply model diffs as gradients, using the optimizer. + start = time.time() + self.optimizer.apply_gradients(zip(model_diff, global_params)) + secs = time.time() - start + + # Convert updated global model weights to + # numpy format for FLModel. + start = time.time() + weights = global_model_tf.get_weights() + w_idx = 0 + new_weights = {} + for key in global_model.params: + w = weights[w_idx] + while global_model.params[key].shape != w.shape: + w_idx += 1 + w = weights[w_idx] + new_weights[key] = w + secs_detach = time.time() - start + + self.info( + f"FedOpt ({type(self.optimizer)}) server model update " + f"round {self.current_round}, " + f"{type(self.lr_scheduler)} " + f"lr: {self.optimizer.learning_rate}, " + f"update: {secs} secs., detach: {secs_detach} secs.", + ) + + global_model.params = new_weights + global_model.meta = aggr_result.meta + + return global_model diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py index b559306440..b32535ac15 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py @@ -109,6 +109,8 @@ class Constant: HEADER_KEY_HORIZONTAL = "xgb.horizontal" HEADER_KEY_ORIGINAL_BUF_SIZE = "xgb.original_buf_size" HEADER_KEY_IN_AGGR = "xgb.in_aggr" + HEADER_KEY_WORLD_SIZE = "xgb.world_size" + HEADER_KEY_SIZE_DICT = "xgb.size_dict" DUMMY_BUFFER_SIZE = 4 @@ -122,8 +124,6 @@ class Constant: class SplitMode: ROW = 0 COL = 1 - COL_SECURE = 2 - ROW_SECURE = 3 # Mapping of text training mode to split mode @@ -132,10 +132,10 @@ class SplitMode: "horizontal": SplitMode.ROW, "v": SplitMode.COL, "vertical": SplitMode.COL, - "hs": SplitMode.ROW_SECURE, - "horizontal_secure": SplitMode.ROW_SECURE, - "vs": SplitMode.COL_SECURE, - "vertical_secure": SplitMode.COL_SECURE, + "hs": SplitMode.ROW, + "horizontal_secure": SplitMode.ROW, + "vs": SplitMode.COL, + "vertical_secure": SplitMode.COL, } SECURE_TRAINING_MODES = {"hs", "horizontal_secure", "vs", "vertical_secure"} diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi index 7ad47596df..7dc3e6dde1 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi @@ -6,7 +6,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = () HALF: _ClassVar[DataType] FLOAT: _ClassVar[DataType] DOUBLE: _ClassVar[DataType] @@ -21,7 +21,7 @@ class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): UINT64: _ClassVar[DataType] class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = () MAX: _ClassVar[ReduceOperation] MIN: _ClassVar[ReduceOperation] SUM: _ClassVar[ReduceOperation] @@ -48,7 +48,7 @@ BITWISE_OR: ReduceOperation BITWISE_XOR: ReduceOperation class AllgatherRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer"] + __slots__ = ("sequence_number", "rank", "send_buffer") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -58,13 +58,13 @@ class AllgatherRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer"] + __slots__ = ("sequence_number", "rank", "send_buffer") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -74,13 +74,13 @@ class AllgatherVRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllreduceRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer", "data_type", "reduce_operation"] + __slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -94,13 +94,13 @@ class AllreduceRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., data_type: _Optional[_Union[DataType, str]] = ..., reduce_operation: _Optional[_Union[ReduceOperation, str]] = ...) -> None: ... class AllreduceReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class BroadcastRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer", "root"] + __slots__ = ("sequence_number", "rank", "send_buffer", "root") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -112,7 +112,7 @@ class BroadcastRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., root: _Optional[int] = ...) -> None: ... class BroadcastReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py index 45eee5c8dd..549d0e4ffc 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: federated.proto +# Protobuf Python Version: 4.25.1 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as federated__pb2 - class FederatedStub(object): """Missing associated documentation comment in .proto file.""" diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py index 1b98829711..0d8e8bec1d 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py @@ -30,7 +30,9 @@ from nvflare.fuel.utils.obj_utils import get_logger from nvflare.utils.cli_utils import get_package_root -LOADER_PARAMS_LIBRARY_PATH = "LIBRARY_PATH" +PLUGIN_PARAM_KEY = "federated_plugin" +PLUGIN_KEY_NAME = "name" +PLUGIN_KEY_PATH = "path" class XGBClientRunner(AppRunner, FLComponent): @@ -135,7 +137,7 @@ def run(self, ctx: dict): self.logger.info(f"server address is {self._server_addr}") communicator_env = { - "xgboost_communicator": "federated", + "dmlc_communicator": "federated", "federated_server_address": f"{self._server_addr}", "federated_world_size": self._world_size, "federated_rank": self._rank, @@ -145,38 +147,35 @@ def run(self, ctx: dict): self.logger.info("XGBoost non-secure training") else: xgb_plugin_name = ConfigService.get_str_var( - name="xgb_plugin_name", conf=SystemConfigs.RESOURCES_CONF, default="nvflare" + name="xgb_plugin_name", conf=SystemConfigs.RESOURCES_CONF, default=None ) - - xgb_loader_params = ConfigService.get_dict_var( - name="xgb_loader_params", conf=SystemConfigs.RESOURCES_CONF, default={} + xgb_plugin_path = ConfigService.get_str_var( + name="xgb_plugin_path", conf=SystemConfigs.RESOURCES_CONF, default=None + ) + xgb_plugin_params: dict = ConfigService.get_dict_var( + name=PLUGIN_PARAM_KEY, conf=SystemConfigs.RESOURCES_CONF, default={} ) - # Library path is frequently used, add a scalar config var and overwrite what's in the dict - xgb_library_path = ConfigService.get_str_var(name="xgb_library_path", conf=SystemConfigs.RESOURCES_CONF) - if xgb_library_path: - xgb_loader_params[LOADER_PARAMS_LIBRARY_PATH] = xgb_library_path + # path and name can be overwritten by scalar configuration + if xgb_plugin_name: + xgb_plugin_params[PLUGIN_KEY_NAME] = xgb_plugin_name - lib_path = xgb_loader_params.get(LOADER_PARAMS_LIBRARY_PATH, None) - if not lib_path: - xgb_loader_params[LOADER_PARAMS_LIBRARY_PATH] = str(get_package_root() / "libs") + if xgb_plugin_path: + xgb_plugin_params[PLUGIN_KEY_PATH] = xgb_plugin_path - xgb_proc_params = ConfigService.get_dict_var( - name="xgb_proc_params", conf=SystemConfigs.RESOURCES_CONF, default={} - ) + # Set default plugin name + if not xgb_plugin_params.get(PLUGIN_KEY_NAME): + xgb_plugin_params[PLUGIN_KEY_NAME] = "cuda_paillier" - self.logger.info( - f"XGBoost secure mode: {self._training_mode} plugin_name: {xgb_plugin_name} " - f"proc_params: {xgb_proc_params} loader_params: {xgb_loader_params}" - ) + if not xgb_plugin_params.get(PLUGIN_KEY_PATH): + # This only works on Linux. Need to support other platforms + lib_ext = "so" + lib_name = f"lib{xgb_plugin_params[PLUGIN_KEY_NAME]}.{lib_ext}" + xgb_plugin_params[PLUGIN_KEY_PATH] = str(get_package_root() / "libs" / lib_name) - communicator_env.update( - { - "plugin_name": xgb_plugin_name, - "proc_params": xgb_proc_params, - "loader_params": xgb_loader_params, - } - ) + self.logger.info(f"XGBoost secure training: {self._training_mode} Params: {xgb_plugin_params}") + + communicator_env[PLUGIN_PARAM_KEY] = xgb_plugin_params with xgb.collective.CommunicatorContext(**communicator_env): # Load the data. Dmatrix must be created with column split mode in CommunicatorContext for vertical FL diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py index 32e708c90e..e4a8796a38 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py @@ -29,8 +29,8 @@ def run(self, ctx: dict): self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) xgb_federated.run_federated_server( + n_workers=self._world_size, port=self._port, - world_size=self._world_size, ) self._stopped = True diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py index 5aad654824..ea5607d828 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py @@ -299,6 +299,10 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): self._process_after_all_gather_v_vertical(fl_ctx) def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): + reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) + size_dict = reply.get_header(Constant.HEADER_KEY_SIZE_DICT) + total_size = sum(size_dict.values()) + self.info(fl_ctx, f"{total_size=} {size_dict=}") rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) # this rcv_buf is a list of replies from ALL clients! rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) @@ -309,7 +313,7 @@ def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): if not self.clear_ghs: # this is non-label client - don't care about the results - dummy = os.urandom(Constant.DUMMY_BUFFER_SIZE) + dummy = os.urandom(total_size) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=dummy, private=True, sticky=False) self.info(fl_ctx, "non-label client: return dummy buffer back to XGB") return @@ -352,16 +356,45 @@ def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): self.info(fl_ctx, f"final aggr: {gid=} features={fid_list}") result = self.data_converter.encode_aggregation_result(final_result, fl_ctx) + + # XGBoost expects every work has a set of histograms. They are already combined here so + # just add zeros + zero_result = final_result + for result_list in zero_result.values(): + for item in result_list: + size = len(item.aggregated_hist) + item.aggregated_hist = [(0, 0)] * size + zero_buf = self.data_converter.encode_aggregation_result(zero_result, fl_ctx) + world_size = len(size_dict) + for _ in range(world_size - 1): + result += zero_buf + + # XGBoost checks that the size of allgatherv is not changed + padding_size = total_size - len(result) + if padding_size > 0: + result += b"\x00" * padding_size + elif padding_size < 0: + self.error(fl_ctx, f"The original size {total_size} is not big enough for data size {len(result)}") + fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) def _process_after_all_gather_v_horizontal(self, fl_ctx: FLContext): + reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) + world_size = reply.get_header(Constant.HEADER_KEY_WORLD_SIZE) encrypted_histograms = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) if not isinstance(encrypted_histograms, CKKSVector): return self._abort(f"rank {rank}: expect a CKKSVector but got {type(encrypted_histograms)}", fl_ctx) histograms = encrypted_histograms.decrypt(secret_key=self.tenseal_context.secret_key()) + result = self.data_converter.encode_histograms_result(histograms, fl_ctx) + + # XGBoost expect every worker returns a histogram, all zeros are returned for other workers + zeros = [0.0] * len(histograms) + zero_buf = self.data_converter.encode_histograms_result(zeros, fl_ctx) + for _ in range(world_size - 1): + result += zero_buf fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) def handle_event(self, event_type: str, fl_ctx: FLContext): @@ -376,7 +409,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): else: self.debug(fl_ctx, "Tenseal module not loaded, horizontal secure XGBoost is not supported") except Exception as ex: - self.debug(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") + self.error(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") self.tenseal_context = None elif event_type == EventType.END_RUN: self.tenseal_context = None diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py index 53e936c7d4..47e44d17d6 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py @@ -39,6 +39,8 @@ def __init__(self): self.aggr_result_dict = None self.aggr_result_to_send = None self.aggr_result_lock = threading.Lock() + self.world_size = 0 + self.size_dict = None if tenseal_imported: decomposers.register() @@ -124,6 +126,10 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext): else: self.info(fl_ctx, f"no aggr data from {rank=}") + if self.size_dict is None: + self.size_dict = {} + + self.size_dict[rank] = request.get_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE) # only send a dummy to the Server fl_ctx.set_prop( key=Constant.PARAM_KEY_SEND_BUF, value=os.urandom(Constant.DUMMY_BUFFER_SIZE), private=True, sticky=False @@ -146,6 +152,7 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL) reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True) reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal) + with self.aggr_result_lock: if not self.aggr_result_to_send: if not self.aggr_result_dict: @@ -159,6 +166,10 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): # reset aggr_result_dict for next gather self.aggr_result_dict = None + self.world_size = len(self.size_dict) + reply.set_header(Constant.HEADER_KEY_WORLD_SIZE, self.world_size) + reply.set_header(Constant.HEADER_KEY_SIZE_DICT, self.size_dict) + if horizontal: length = self.aggr_result_to_send.size() else: diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py new file mode 100644 index 0000000000..6540eb519c --- /dev/null +++ b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xgboost as xgb + +from nvflare.app_opt.xgboost.data_loader import XGBDataLoader +from nvflare.app_opt.xgboost.histogram_based_v2.defs import TRAINING_MODE_MAPPING, SplitMode + + +class SecureDataLoader(XGBDataLoader): + def __init__(self, rank: int, folder: str): + """Reads CSV dataset and return XGB data matrix in vertical secure mode. + + Args: + rank: Rank of the site + folder: Folder to find the CSV files + """ + self.rank = rank + self.folder = folder + + def load_data(self, client_id: str, training_mode: str): + + train_path = f"{self.folder}/site-{self.rank + 1}/train.csv" + valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv" + + if training_mode not in TRAINING_MODE_MAPPING: + raise ValueError(f"Invalid training_mode: {training_mode}") + + data_split_mode = TRAINING_MODE_MAPPING[training_mode] + + if self.rank == 0 or data_split_mode == SplitMode.ROW: + label = "&label_column=0" + else: + label = "" + + train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=data_split_mode) + valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=data_split_mode) + + return train_data, valid_data diff --git a/nvflare/job_config/fed_job.py b/nvflare/job_config/fed_job.py index bd72a8fce6..883c83ad29 100644 --- a/nvflare/job_config/fed_job.py +++ b/nvflare/job_config/fed_job.py @@ -168,11 +168,11 @@ def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, key_me if metrics are a `dict`, `key_metric` can select the metric used for global model selection. Defaults to "accuracy". """ - self.job_name = name + self.name = name self.key_metric = key_metric self.clients = [] self.job: FedJobConfig = FedJobConfig( - job_name=self.job_name, min_clients=min_clients, mandatory_clients=mandatory_clients + job_name=self.name, min_clients=min_clients, mandatory_clients=mandatory_clients ) self._deploy_map = {} self._deployed = False diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index 274ad2a889..3952ea5db5 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -174,6 +174,10 @@ def setup(self): self._cleanup_workspace() init_security_content_service(self.args.workspace) + os.makedirs(os.path.join(self.simulator_root, "server")) + log_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.LOG_FILE_NAME) + add_logfile_handler(log_file) + try: data_bytes, job_name, meta = self.validate_job_data() @@ -501,9 +505,6 @@ def start_server_app(self, args): args.workspace = os.path.join(self.simulator_root, "server") os.chdir(args.workspace) - log_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.LOG_FILE_NAME) - add_logfile_handler(log_file) - args.server_config = os.path.join("config", JobConstants.SERVER_JOB_CONFIG) app_custom_folder = os.path.join(app_server_root, "custom") if os.path.isdir(app_custom_folder) and app_custom_folder not in sys.path: diff --git a/setup.cfg b/setup.cfg index 249d274bc3..967d14c71d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,7 +23,7 @@ install_requires = Flask-SQLAlchemy==3.1.1 SQLAlchemy==2.0.16 grpcio==1.62.1 - gunicorn>=20.1.0 + gunicorn>=22.0.0 numpy<=1.26.4 protobuf==4.24.4 psutil>=5.9.1