From 45a16d569b650378f26b3420e69c403f99272d93 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang <100308595+nvidianz@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:03:01 -0400 Subject: [PATCH] XGBoost plugin with new API (#2725) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Updated FOBS readme to add DatumManager, added agrpcs as secure scheme * Implemented LocalPlugin * Refactoring plugin * Fixed formats * Fixed horizontal secure isses with mismatching algather-v sizes * Added padding to the buffer so it's big enough for histograms * Format fix * Changed log level for tenseal exceptions * Fixed a typo * Added debug statements * Fixed LocalPlugin horizontal bug * Added #include * Added docstring to BasePlugin --------- Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- .../xgboost/encryption_plugins/.editorconfig | 11 + .../xgboost/encryption_plugins/CMakeLists.txt | 41 ++ .../xgboost/encryption_plugins/README.md | 9 + .../src/README.md | 0 .../src/dam/README.md | 0 .../xgboost/encryption_plugins/src/dam/dam.cc | 274 +++++++++++++ .../src/include/base_plugin.h | 155 +++++++ .../encryption_plugins/src/include/dam.h | 143 +++++++ .../src/include/data_set_ids.h | 23 ++ .../src/include/delegated_plugin.h | 66 +++ .../src/include/local_plugin.h | 107 +++++ .../src/include/nvflare_plugin.h} | 48 ++- .../src/include/pass_thru_plugin.h | 41 ++ .../encryption_plugins/src/include/util.h | 18 + .../src/plugins/delegated_plugin.cc | 36 ++ .../src/plugins/local_plugin.cc | 366 +++++++++++++++++ .../src/plugins/nvflare_plugin.cc | 297 ++++++++++++++ .../src/plugins/pass_thru_plugin.cc | 130 ++++++ .../src/plugins/plugin_main.cc | 184 +++++++++ .../encryption_plugins/src/plugins/util.cc | 99 +++++ .../encryption_plugins/tests/CMakeLists.txt | 14 + .../tests/test_dam.cc | 27 +- .../tests/test_main.cc | 0 .../tests/test_tenseal.py | 0 integration/xgboost/processor/CMakeLists.txt | 46 --- integration/xgboost/processor/README.md | 11 - integration/xgboost/processor/src/dam/dam.cc | 146 ------- .../xgboost/processor/src/include/dam.h | 93 ----- .../src/nvflare-plugin/nvflare_processor.cc | 378 ------------------ .../xgboost/processor/tests/CMakeLists.txt | 14 - .../xgboost/histogram_based_v2/defs.py | 12 +- .../proto/federated_pb2.pyi | 20 +- .../proto/federated_pb2_grpc.py | 5 +- .../runners/xgb_client_runner.py | 53 ++- .../runners/xgb_server_runner.py | 2 +- .../histogram_based_v2/sec/client_handler.py | 37 +- .../histogram_based_v2/sec/server_handler.py | 11 + .../histogram_based_v2/secure_data_loader.py | 50 +++ 38 files changed, 2209 insertions(+), 758 deletions(-) create mode 100644 integration/xgboost/encryption_plugins/.editorconfig create mode 100644 integration/xgboost/encryption_plugins/CMakeLists.txt create mode 100644 integration/xgboost/encryption_plugins/README.md rename integration/xgboost/{processor => encryption_plugins}/src/README.md (100%) rename integration/xgboost/{processor => encryption_plugins}/src/dam/README.md (100%) create mode 100644 integration/xgboost/encryption_plugins/src/dam/dam.cc create mode 100644 integration/xgboost/encryption_plugins/src/include/base_plugin.h create mode 100644 integration/xgboost/encryption_plugins/src/include/dam.h create mode 100644 integration/xgboost/encryption_plugins/src/include/data_set_ids.h create mode 100644 integration/xgboost/encryption_plugins/src/include/delegated_plugin.h create mode 100644 integration/xgboost/encryption_plugins/src/include/local_plugin.h rename 
integration/xgboost/{processor/src/include/nvflare_processor.h => encryption_plugins/src/include/nvflare_plugin.h} (77%) create mode 100644 integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h create mode 100644 integration/xgboost/encryption_plugins/src/include/util.h create mode 100644 integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc create mode 100644 integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc create mode 100644 integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc create mode 100644 integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc create mode 100644 integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc create mode 100644 integration/xgboost/encryption_plugins/src/plugins/util.cc create mode 100644 integration/xgboost/encryption_plugins/tests/CMakeLists.txt rename integration/xgboost/{processor => encryption_plugins}/tests/test_dam.cc (65%) rename integration/xgboost/{processor => encryption_plugins}/tests/test_main.cc (100%) rename integration/xgboost/{processor => encryption_plugins}/tests/test_tenseal.py (100%) delete mode 100644 integration/xgboost/processor/CMakeLists.txt delete mode 100644 integration/xgboost/processor/README.md delete mode 100644 integration/xgboost/processor/src/dam/dam.cc delete mode 100644 integration/xgboost/processor/src/include/dam.h delete mode 100644 integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc delete mode 100644 integration/xgboost/processor/tests/CMakeLists.txt create mode 100644 nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py diff --git a/integration/xgboost/encryption_plugins/.editorconfig b/integration/xgboost/encryption_plugins/.editorconfig new file mode 100644 index 0000000000..97a7bc133a --- /dev/null +++ b/integration/xgboost/encryption_plugins/.editorconfig @@ -0,0 +1,11 @@ +root = true + +[*] +charset=utf-8 +indent_style = space +indent_size = 2 +insert_final_newline = true + +[*.py] +indent_style = space +indent_size = 4 diff --git a/integration/xgboost/encryption_plugins/CMakeLists.txt b/integration/xgboost/encryption_plugins/CMakeLists.txt new file mode 100644 index 0000000000..f5d71dd61c --- /dev/null +++ b/integration/xgboost/encryption_plugins/CMakeLists.txt @@ -0,0 +1,41 @@ +cmake_minimum_required(VERSION 3.19) +project(xgb_nvflare LANGUAGES CXX C VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Debug) + +option(GOOGLE_TEST "Build google tests" OFF) + +file(GLOB_RECURSE LIB_SRC "src/*.cc") + +add_library(nvflare SHARED ${LIB_SRC}) +set_target_properties(nvflare PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + ENABLE_EXPORTS ON +) +target_include_directories(nvflare PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include) + +if (APPLE) + add_link_options("LINKER:-object_path_lto,$_lto.o") + add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache") +endif () + +#-- Unit Tests +if(GOOGLE_TEST) + find_package(GTest REQUIRED) + enable_testing() + add_executable(nvflare_test) + target_link_libraries(nvflare_test PRIVATE nvflare) + + + target_include_directories(nvflare_test PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include) + + add_subdirectory(${xgb_nvflare_SOURCE_DIR}/tests) + + add_test( + NAME TestNvflarePlugins + COMMAND nvflare_test + WORKING_DIRECTORY ${xgb_nvflare_BINARY_DIR}) + +endif() diff --git a/integration/xgboost/encryption_plugins/README.md b/integration/xgboost/encryption_plugins/README.md new file mode 100644 index 
0000000000..57f2c4621e --- /dev/null +++ b/integration/xgboost/encryption_plugins/README.md @@ -0,0 +1,9 @@ +# Build Instruction + +cd NVFlare/integration/xgboost/encryption_plugins +mkdir build +cd build +cmake .. +make + +The library is libxgb_nvflare.so diff --git a/integration/xgboost/processor/src/README.md b/integration/xgboost/encryption_plugins/src/README.md similarity index 100% rename from integration/xgboost/processor/src/README.md rename to integration/xgboost/encryption_plugins/src/README.md diff --git a/integration/xgboost/processor/src/dam/README.md b/integration/xgboost/encryption_plugins/src/dam/README.md similarity index 100% rename from integration/xgboost/processor/src/dam/README.md rename to integration/xgboost/encryption_plugins/src/dam/README.md diff --git a/integration/xgboost/encryption_plugins/src/dam/dam.cc b/integration/xgboost/encryption_plugins/src/dam/dam.cc new file mode 100644 index 0000000000..9fdb7d8582 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/dam/dam.cc @@ -0,0 +1,274 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "dam.h" + + +void print_hex(const uint8_t *buffer, std::size_t size) { + std::cout << std::hex; + for (int i = 0; i < size; i++) { + int c = buffer[i]; + std::cout << c << " "; + } + std::cout << std::endl << std::dec; +} + +void print_buffer(const uint8_t *buffer, std::size_t size) { + if (size <= 64) { + std::cout << "Whole buffer: " << size << " bytes" << std::endl; + print_hex(buffer, size); + return; + } + + std::cout << "First chunk, Total: " << size << " bytes" << std::endl; + print_hex(buffer, 32); + std::cout << "Last chunk, Offset: " << size-16 << " bytes" << std::endl; + print_hex(buffer+size-32, 32); +} + +size_t align(const size_t length) { + return ((length + 7)/8)*8; +} + +// DamEncoder ====== +void DamEncoder::AddBuffer(const Buffer &buffer) { + if (debug_) { + std::cout << "AddBuffer called, size: " << buffer.buf_size << std::endl; + } + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(buffer, buf_size); + entries_.emplace_back(kDataTypeBuffer, static_cast(buffer.buffer), buffer.buf_size); +} + +void DamEncoder::AddFloatArray(const std::vector &value) { + if (debug_) { + std::cout << "AddFloatArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(reinterpret_cast(value.data()), value.size() * 8); + entries_.emplace_back(kDataTypeFloatArray, reinterpret_cast(value.data()), value.size()); +} + +void DamEncoder::AddIntArray(const std::vector &value) { + if (debug_) { + std::cout << "AddIntArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(buffer, buf_size); + entries_.emplace_back(kDataTypeIntArray, 
reinterpret_cast(value.data()), value.size()); +} + +void DamEncoder::AddBufferArray(const std::vector &value) { + if (debug_) { + std::cout << "AddBufferArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + size_t size = 0; + for (auto &buf: value) { + size += buf.buf_size; + } + size += 8*value.size(); + entries_.emplace_back(kDataTypeBufferArray, reinterpret_cast(&value), size); +} + + +std::uint8_t * DamEncoder::Finish(size_t &size) { + encoded_ = true; + + size = CalculateSize(); + auto buf = static_cast(calloc(size, 1)); + auto pointer = buf; + auto sig = local_version_ ? kSignatureLocal : kSignature; + memcpy(pointer, sig, strlen(sig)); + memcpy(pointer+8, &size, 8); + memcpy(pointer+16, &data_set_id_, 8); + + pointer += kPrefixLen; + for (auto& entry : entries_) { + std::size_t len; + if (entry.data_type == kDataTypeBufferArray) { + auto buffers = reinterpret_cast *>(entry.pointer); + memcpy(pointer, &entry.data_type, 8); + pointer += 8; + auto array_size = static_cast(buffers->size()); + memcpy(pointer, &array_size, 8); + pointer += 8; + auto sizes = reinterpret_cast(pointer); + for (auto &item : *buffers) { + *sizes = static_cast(item.buf_size); + sizes++; + } + len = 8*buffers->size(); + auto buf_ptr = pointer + len; + for (auto &item : *buffers) { + if (item.buf_size > 0) { + memcpy(buf_ptr, item.buffer, item.buf_size); + } + buf_ptr += item.buf_size; + len += item.buf_size; + } + } else { + memcpy(pointer, &entry.data_type, 8); + pointer += 8; + memcpy(pointer, &entry.size, 8); + pointer += 8; + len = entry.size * entry.ItemSize(); + if (len) { + memcpy(pointer, entry.pointer, len); + } + } + pointer += align(len); + } + + if ((pointer - buf) != size) { + std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl; + return nullptr; + } + + return buf; +} + +std::size_t DamEncoder::CalculateSize() { + std::size_t size = kPrefixLen; + + for (auto& entry : entries_) { + size += 16; // The Type and Len + auto len = entry.size * entry.ItemSize(); + size += align(len); + } + + return size; +} + + +// DamDecoder ====== + +DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version, bool debug) { + local_version_ = local_version; + buffer_ = buffer; + buf_size_ = size; + pos_ = buffer + kPrefixLen; + debug_ = debug; + + if (size >= kPrefixLen) { + memcpy(&len_, buffer + 8, 8); + memcpy(&data_set_id_, buffer + 16, 8); + } else { + len_ = 0; + data_set_id_ = 0; + } +} + +bool DamDecoder::IsValid() const { + auto sig = local_version_ ? 
kSignatureLocal : kSignature; + return buf_size_ >= kPrefixLen && memcmp(buffer_, sig, strlen(sig)) == 0; +} + +Buffer DamDecoder::DecodeBuffer() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeBuffer) { + std::cout << "Data type " << type << " doesn't match bytes" << std::endl; + return {}; + } + pos_ += 8; + + auto size = *reinterpret_cast(pos_); + pos_ += 8; + + if (size == 0) { + return {}; + } + + auto ptr = reinterpret_cast(pos_); + pos_ += align(size); + return{ ptr, static_cast(size)}; +} + +std::vector DamDecoder::DecodeIntArray() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeIntArray) { + std::cout << "Data type " << type << " doesn't match Int Array" << std::endl; + return {}; + } + pos_ += 8; + + auto array_size = *reinterpret_cast(pos_); + pos_ += 8; + auto ptr = reinterpret_cast(pos_); + pos_ += align(8 * array_size); + return {ptr, ptr + array_size}; +} + +std::vector DamDecoder::DecodeFloatArray() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeFloatArray) { + std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; + return {}; + } + pos_ += 8; + + auto array_size = *reinterpret_cast(pos_); + pos_ += 8; + + auto ptr = reinterpret_cast(pos_); + pos_ += align(8 * array_size); + return {ptr, ptr + array_size}; +} + +std::vector DamDecoder::DecodeBufferArray() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeBufferArray) { + std::cout << "Data type " << type << " doesn't match Bytes Array" << std::endl; + return {}; + } + pos_ += 8; + + auto num = *reinterpret_cast(pos_); + pos_ += 8; + + auto size_ptr = reinterpret_cast(pos_); + auto buf_ptr = pos_ + 8 * num; + size_t total_size = 8 * num; + auto result = std::vector(num); + for (int i = 0; i < num; i++) { + auto size = size_ptr[i]; + if (buf_size_ > 0) { + result[i].buf_size = size; + result[i].buffer = buf_ptr; + buf_ptr += size; + } + total_size += size; + } + + pos_ += align(total_size); + return result; +} diff --git a/integration/xgboost/encryption_plugins/src/include/base_plugin.h b/integration/xgboost/encryption_plugins/src/include/base_plugin.h new file mode 100644 index 0000000000..dddd5a7911 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/base_plugin.h @@ -0,0 +1,155 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // for uint8_t, uint32_t, int32_t, int64_t +#include // for string_view +#include // for pair +#include // for vector +#include +#include +#include + +#include "util.h" + +namespace nvflare { + +/** + * @brief Abstract interface for the encryption plugin + * + * All plugin implementations must inherit this class. + */ +class BasePlugin { +protected: + bool debug_ = false; + bool print_timing_ = false; + bool dam_debug_ = false; + +public: +/** + * @brief Constructor + * + * All inherited classes should call this constructor. 
+ * + * @param args Entries from federated_plugin in communicator environments. + */ + explicit BasePlugin( + std::vector> const &args) { + debug_ = get_bool(args, "debug"); + print_timing_ = get_bool(args, "print_timing"); + dam_debug_ = get_bool(args, "dam_debug"); + } + + /** + * @brief Destructor + */ + virtual ~BasePlugin() = default; + + /** + * @brief Identity for the plugin used for debug + * + * This is a string with instance address and process id. + */ + std::string Ident() { + std::stringstream ss; + ss << std::hex << std::uppercase << std::setw(sizeof(void*) * 2) << std::setfill('0') << + reinterpret_cast(this); + return ss.str() + "-" + std::to_string(getpid()); + } + + /** + * @brief Encrypt the gradient pairs + * + * @param in_gpair Input g and h pairs for each record + * @param n_in The array size (2xnum_of_records) + * @param out_gpair Pointer to encrypted buffer + * @param n_out Encrypted buffer size + */ + virtual void EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, std::size_t *n_out) = 0; + + /** + * @brief Process encrypted gradient pairs + * + * @param in_gpair Encrypted gradient pairs + * @param n_bytes Buffer size of Encrypted gradient + * @param out_gpair Pointer to decrypted gradient pairs + * @param out_n_bytes Decrypted buffer size + */ + virtual void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) = 0; + + /** + * @brief Reset the histogram context + * + * @param cutptrs Cut-pointers for the flattened histograms + * @param cutptr_len cutptrs array size (number of features plus one) + * @param bin_idx An array (flattened matrix) of slot index for each record/feature + * @param n_idx The size of above array + */ + virtual void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, + std::int32_t const *bin_idx, std::size_t n_idx) = 0; + + /** + * @brief Encrypt histograms for horizontal training + * + * @param in_histogram The array for the histogram + * @param len The array size + * @param out_hist Pointer to encrypted buffer + * @param out_len Encrypted buffer size + */ + virtual void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Process encrypted histograms for horizontal training + * + * @param buffer Buffer for encrypted histograms + * @param len Buffer size of encrypted histograms + * @param out_hist Pointer to decrypted histograms + * @param out_len Size of above array + */ + virtual void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, + double **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Build histograms in encrypted space for vertical training + * + * @param ridx Pointer to a matrix of row IDs for each node + * @param sizes An array of sizes of each node + * @param nidx An array for each node ID + * @param len Number of nodes + * @param out_hist Pointer to encrypted histogram buffer + * @param out_len Buffer size + */ + virtual void BuildEncryptedHistVert(std::uint64_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Decrypt histogram for vertical training + * + * @param hist_buffer Encrypted histogram buffer + * @param len Buffer size of encrypted histogram + * @param out Pointer to decrypted histograms + * @param out_len Size of above array + */ + virtual void 
SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len) = 0; +}; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/dam.h b/integration/xgboost/encryption_plugins/src/include/dam.h new file mode 100644 index 0000000000..8677a413b1 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/dam.h @@ -0,0 +1,143 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +constexpr char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 +constexpr char kSignatureLocal[] = "NVDADAML"; // DAM Local version +constexpr int kPrefixLen = 24; + +constexpr int kDataTypeInt = 1; +constexpr int kDataTypeFloat = 2; +constexpr int kDataTypeString = 3; +constexpr int kDataTypeBuffer = 4; +constexpr int kDataTypeIntArray = 257; +constexpr int kDataTypeFloatArray = 258; +constexpr int kDataTypeBufferArray = 259; +constexpr int kDataTypeMap = 1025; + +/*! \brief A replacement for std::span */ +class Buffer { +public: + void *buffer; + size_t buf_size; + bool allocated; + + Buffer() : buffer(nullptr), buf_size(0), allocated(false) { + } + + Buffer(void *buffer, size_t buf_size, bool allocated=false) : + buffer(buffer), buf_size(buf_size), allocated(allocated) { + } + + Buffer(const Buffer &that): + buffer(that.buffer), buf_size(that.buf_size), allocated(false) { + } +}; + +class Entry { + public: + int64_t data_type; + const uint8_t * pointer; + int64_t size; + + Entry(int64_t data_type, const uint8_t *pointer, int64_t size) { + this->data_type = data_type; + this->pointer = pointer; + this->size = size; + } + + [[nodiscard]] std::size_t ItemSize() const + { + size_t item_size; + switch (data_type) { + case kDataTypeBuffer: + case kDataTypeString: + case kDataTypeBufferArray: + item_size = 1; + break; + default: + item_size = 8; + } + return item_size; + } +}; + +class DamEncoder { + private: + bool encoded_ = false; + bool local_version_ = false; + bool debug_ = false; + int64_t data_set_id_; + std::vector entries_; + + public: + explicit DamEncoder(int64_t data_set_id, bool local_version=false, bool debug=false) { + data_set_id_ = data_set_id; + local_version_ = local_version; + debug_ = debug; + + } + + void AddBuffer(const Buffer &buffer); + + void AddIntArray(const std::vector &value); + + void AddFloatArray(const std::vector &value); + + void AddBufferArray(const std::vector &value); + + std::uint8_t * Finish(size_t &size); + + private: + std::size_t CalculateSize(); +}; + +class DamDecoder { + private: + bool local_version_ = false; + std::uint8_t *buffer_ = nullptr; + std::size_t buf_size_ = 0; + std::uint8_t *pos_ = nullptr; + std::size_t remaining_ = 0; + int64_t data_set_id_ = 0; + int64_t len_ = 0; + bool debug_ = false; + + public: + explicit DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version=false, bool debug=false); + + [[nodiscard]] std::size_t Size() 
const { + return len_; + } + + [[nodiscard]] int64_t GetDataSetId() const { + return data_set_id_; + } + + [[nodiscard]] bool IsValid() const; + + Buffer DecodeBuffer(); + + std::vector DecodeIntArray(); + + std::vector DecodeFloatArray(); + + std::vector DecodeBufferArray(); +}; + +void print_buffer(const uint8_t *buffer, std::size_t size); diff --git a/integration/xgboost/encryption_plugins/src/include/data_set_ids.h b/integration/xgboost/encryption_plugins/src/include/data_set_ids.h new file mode 100644 index 0000000000..98eb20e838 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/data_set_ids.h @@ -0,0 +1,23 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +constexpr int kDataSetGHPairs = 1; +constexpr int kDataSetAggregation = 2; +constexpr int kDataSetAggregationWithFeatures = 3; +constexpr int kDataSetAggregationResult = 4; +constexpr int kDataSetHistograms = 5; +constexpr int kDataSetHistogramResult = 6; diff --git a/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h b/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h new file mode 100644 index 0000000000..7b4f353b21 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h @@ -0,0 +1,66 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
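For readers new to the DAM (Direct Accessible Marshalling) format declared in dam.h above, the following is a minimal, illustrative round trip through the DamEncoder/DamDecoder API added by this patch. It is not part of the patch itself; it assumes dam.h is on the include path, and the data-set id value is arbitrary, chosen only for the example.

#include <cstdlib>
#include <iostream>
#include <vector>

#include "dam.h"  // DamEncoder / DamDecoder from this patch

int main() {
  // Encode one float array under an arbitrary data-set id (example only).
  DamEncoder encoder(/*data_set_id=*/42, /*local_version=*/false, /*debug=*/false);
  std::vector<double> histogram{1.0, 2.0, 3.0};
  encoder.AddFloatArray(histogram);

  std::size_t size = 0;
  std::uint8_t *buf = encoder.Finish(size);  // caller owns the returned buffer

  // Decode: the 24-byte prefix carries the signature, total length and data-set id.
  DamDecoder decoder(buf, size, /*local_version=*/false, /*debug=*/false);
  if (decoder.IsValid() && decoder.GetDataSetId() == 42) {
    std::vector<double> decoded = decoder.DecodeFloatArray();
    std::cout << "Decoded " << decoded.size() << " doubles" << std::endl;
  }

  free(buf);  // Finish() allocates the buffer with calloc
  return 0;
}

Entries must be decoded in the same order they were added; each entry occupies a 16-byte type/length header plus its payload rounded up to an 8-byte boundary.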
+ */ +#pragma once +#include "base_plugin.h" + +namespace nvflare { + +// Plugin that delegates to other real plugins +class DelegatedPlugin : public BasePlugin { + + BasePlugin *plugin_{nullptr}; + +public: + explicit DelegatedPlugin(std::vector> const &args); + + ~DelegatedPlugin() override { + delete plugin_; + } + + void EncryptGPairs(const float* in_gpair, std::size_t n_in, std::uint8_t** out_gpair, std::size_t* n_out) override { + plugin_->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + } + + void SyncEncryptedGPairs(const std::uint8_t* in_gpair, std::size_t n_bytes, const std::uint8_t** out_gpair, + std::size_t* out_n_bytes) override { + plugin_->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, out_n_bytes); + } + + void ResetHistContext(const std::uint32_t* cutptrs, std::size_t cutptr_len, const std::int32_t* bin_idx, + std::size_t n_idx) override { + plugin_->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + } + + void BuildEncryptedHistHori(const double* in_histogram, std::size_t len, std::uint8_t** out_hist, + std::size_t* out_len) override { + plugin_->BuildEncryptedHistHori(in_histogram, len, out_hist, out_len); + } + + void SyncEncryptedHistHori(const std::uint8_t* buffer, std::size_t len, double** out_hist, + std::size_t* out_len) override { + plugin_->SyncEncryptedHistHori(buffer, len, out_hist, out_len); + } + + void BuildEncryptedHistVert(const std::uint64_t** ridx, const std::size_t* sizes, const std::int32_t* nidx, + std::size_t len, std::uint8_t** out_hist, std::size_t* out_len) override { + plugin_->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + } + + void SyncEncryptedHistVert(std::uint8_t* hist_buffer, std::size_t len, double** out, std::size_t* out_len) override { + plugin_->SyncEncryptedHistVert(hist_buffer, len, out, out_len); + } +}; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/local_plugin.h b/integration/xgboost/encryption_plugins/src/include/local_plugin.h new file mode 100644 index 0000000000..2022322266 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/local_plugin.h @@ -0,0 +1,107 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
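The DelegatedPlugin above re-enumerates the full BasePlugin surface by forwarding every call. As a rough, hypothetical sketch of how a caller drives the gradient-pair half of that interface (the real call sites live inside XGBoost's federated hooks, and the exact orchestration is assumed here, not taken from this patch):

#include <cstddef>
#include <cstdint>
#include <vector>

#include "base_plugin.h"

// Hypothetical driver for one round; names and call order are illustrative only.
void example_gpair_round(nvflare::BasePlugin &plugin,
                         const std::vector<float> &gpairs,        // flattened (g, h) pairs
                         const std::vector<std::uint32_t> &cuts,  // histogram cut pointers
                         const std::vector<std::int32_t> &bins) { // per-record/per-feature slots
  std::uint8_t *enc = nullptr;
  std::size_t enc_len = 0;
  plugin.EncryptGPairs(gpairs.data(), gpairs.size(), &enc, &enc_len);

  // The encrypted buffer is exchanged; each rank then synchronizes it.
  const std::uint8_t *synced = nullptr;
  std::size_t synced_len = 0;
  plugin.SyncEncryptedGPairs(enc, enc_len, &synced, &synced_len);

  // Histogram context is (re)set with the quantile cuts and bin indices.
  plugin.ResetHistContext(cuts.data(), cuts.size(), bins.data(), bins.size());
}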
+ */ +#pragma once + +#include "base_plugin.h" +#include "dam.h" + +namespace nvflare { + +// A base plugin for all plugins that handle encryption locally in C++ +class LocalPlugin : public BasePlugin { +protected: + std::vector gh_pairs_; + std::vector encrypted_gh_; + std::vector histo_; + std::vector cuts_; + std::vector slots_; + std::vector buffer_; + +public: + explicit LocalPlugin(std::vector> const &args) : + BasePlugin(args) {} + + ~LocalPlugin() override = default; + + void EncryptGPairs(const float *in_gpair, std::size_t n_in, std::uint8_t **out_gpair, + std::size_t *n_out) override; + + void SyncEncryptedGPairs(const std::uint8_t *in_gpair, std::size_t n_bytes, const std::uint8_t **out_gpair, + std::size_t *out_n_bytes) override; + + void ResetHistContext(const std::uint32_t *cutptrs, std::size_t cutptr_len, const std::int32_t *bin_idx, + std::size_t n_idx) override; + + void BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) override; + + void SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) override; + + void BuildEncryptedHistVert(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) override; + + void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, double **out, + std::size_t *out_len) override; + + // Method needs to be implemented by local plugins + + /*! + * \brief Encrypt a vector of float-pointing numbers + * \param cleartext A vector of numbers in cleartext + * \return A buffer with serialized ciphertext + */ + virtual Buffer EncryptVector(const std::vector &cleartext) = 0; + + /*! + * \brief Decrypt a serialized ciphertext into an array of numbers + * \param ciphertext A serialzied buffer of ciphertext + * \return An array of numbers + */ + virtual std::vector DecryptVector(const std::vector &ciphertext) = 0; + + /*! + * \brief Add the G&H pairs for a series of samples + * \param sample_ids A map of slot number and an array of sample IDs + * \return A map of the serialized encrypted sum of G and H for each slot + * The input and output maps must have the same size + */ + virtual std::map AddGHPairs(const std::map> &sample_ids) = 0; + + /*! 
+ * \brief Free encrypted data buffer + * \param ciphertext The buffer for encrypted data + */ + virtual void FreeEncryptedData(Buffer &ciphertext) { + if (ciphertext.allocated && ciphertext.buffer != nullptr) { + free(ciphertext.buffer); + ciphertext.allocated = false; + } + ciphertext.buffer = nullptr; + ciphertext.buf_size = 0; + }; + +private: + + void BuildEncryptedHistVertActive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len); + + void BuildEncryptedHistVertPassive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len); + +}; + +} // namespace nvflare diff --git a/integration/xgboost/processor/src/include/nvflare_processor.h b/integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h similarity index 77% rename from integration/xgboost/processor/src/include/nvflare_processor.h rename to integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h index cb7076eaf4..87f47d622c 100644 --- a/integration/xgboost/processor/src/include/nvflare_processor.h +++ b/integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h @@ -14,61 +14,65 @@ * limitations under the License. */ #pragma once + #include // for uint8_t, uint32_t, int32_t, int64_t #include // for string_view #include // for pair #include // for vector -const int kDataSetHGPairs = 1; -const int kDataSetAggregation = 2; -const int kDataSetAggregationWithFeatures = 3; -const int kDataSetAggregationResult = 4; -const int kDataSetHistograms = 5; -const int kDataSetHistogramResult = 6; - -// Opaque pointer type for the C API. -typedef void *FederatedPluginHandle; // NOLINT +#include "base_plugin.h" namespace nvflare { -// Plugin that uses Python tenseal and GRPC. -class TensealPlugin { + +// Plugin that uses Python TenSeal and GRPC. +class NvflarePlugin : public BasePlugin { // Buffer for storing encrypted gradient pairs. std::vector encrypted_gpairs_; // Buffer for histogram cut pointers (indptr of a CSC). std::vector cut_ptrs_; // Buffer for histogram index. std::vector bin_idx_; + std::vector gh_pairs_; bool feature_sent_{false}; // The feature index. std::vector features_; // Buffer for output histogram. 
std::vector encrypted_hist_; - std::vector hist_; + // A temporary buffer to hold return value + std::vector buffer_; + // Buffer for clear histogram + std::vector histo_; public: - TensealPlugin( - std::vector> const &args); + explicit NvflarePlugin(std::vector> const &args) : BasePlugin(args) {} + + ~NvflarePlugin() override = default; + // Gradient pairs void EncryptGPairs(float const *in_gpair, std::size_t n_in, - std::uint8_t **out_gpair, std::size_t *n_out); + std::uint8_t **out_gpair, std::size_t *n_out) override; + void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, std::uint8_t const **out_gpair, - std::size_t *out_n_bytes); + std::size_t *out_n_bytes) override; // Histogram void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, - std::int32_t const *bin_idx, std::size_t n_idx); + std::int32_t const *bin_idx, std::size_t n_idx) override; + void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, - std::uint8_t **out_hist, std::size_t *out_len); + std::uint8_t **out_hist, std::size_t *out_len) override; + void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, - double **out_hist, std::size_t *out_len); + double **out_hist, std::size_t *out_len) override; - void BuildEncryptedHistVert(std::size_t const **ridx, + void BuildEncryptedHistVert(std::uint64_t const **ridx, std::size_t const *sizes, std::int32_t const *nidx, std::size_t len, - std::uint8_t **out_hist, std::size_t *out_len); + std::uint8_t **out_hist, std::size_t *out_len) override; + void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, - double **out, std::size_t *out_len); + double **out, std::size_t *out_len) override; }; } // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h b/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h new file mode 100644 index 0000000000..3abeee4b56 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "local_plugin.h" + +namespace nvflare { + // A pass-through plugin that doesn't encrypt any data + class PassThruPlugin : public LocalPlugin { + public: + explicit PassThruPlugin(std::vector> const &args) : + LocalPlugin(args) {} + + ~PassThruPlugin() override = default; + + // Horizontal in local plugin still goes through NVFlare, so it needs to be overwritten + void BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) override; + + void SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) override; + + Buffer EncryptVector(const std::vector &cleartext) override; + + std::vector DecryptVector(const std::vector &ciphertext) override; + + std::map AddGHPairs(const std::map> &sample_ids) override; + }; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/util.h b/integration/xgboost/encryption_plugins/src/include/util.h new file mode 100644 index 0000000000..bb8ba16d1a --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/util.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include + +std::vector> distribute_work(size_t num_jobs, size_t num_workers); + +uint32_t to_int(double d); + +double to_double(uint32_t i); + +std::string get_string(std::vector> const &args, + std::string_view const &key,std::string_view default_value = ""); + +bool get_bool(std::vector> const &args, + const std::string &key, bool default_value = false); + +int get_int(std::vector> const &args, + const std::string &key, int default_value = 0); diff --git a/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc new file mode 100644 index 0000000000..a026822799 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
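PassThruPlugin above is the simplest instance of the extension point described in local_plugin.h: a concrete plugin only supplies EncryptVector, DecryptVector and AddGHPairs, while LocalPlugin takes care of the DAM serialization and histogram bookkeeping. A hypothetical skeleton follows; the class name and the no-op "cipher" are invented for illustration, and the template parameters mirror the signatures shown in this patch.

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <map>
#include <string_view>
#include <utility>
#include <vector>

#include "local_plugin.h"

namespace nvflare {

// Hypothetical skeleton of a C++ encryption plugin; the "cipher" here is a no-op.
class MyCipherPlugin : public LocalPlugin {
 public:
  explicit MyCipherPlugin(
      std::vector<std::pair<std::string_view, std::string_view>> const &args)
      : LocalPlugin(args) {}

  Buffer EncryptVector(const std::vector<double> &cleartext) override {
    // A real plugin would serialize ciphertext here (e.g. Paillier/CKKS blobs).
    auto size = cleartext.size() * sizeof(double);
    auto buf = static_cast<std::uint8_t *>(malloc(size));
    memcpy(buf, cleartext.data(), size);
    return {buf, size, /*allocated=*/true};
  }

  std::vector<double> DecryptVector(const std::vector<Buffer> &ciphertext) override {
    std::vector<double> out;
    for (const auto &c : ciphertext) {
      auto p = static_cast<const double *>(c.buffer);
      out.insert(out.end(), p, p + c.buf_size / sizeof(double));
    }
    return out;
  }

  std::map<int, Buffer> AddGHPairs(
      const std::map<int, std::vector<int>> &sample_ids) override {
    // A real plugin adds the encrypted G/H values homomorphically per slot;
    // this sketch just returns an empty buffer for every slot.
    std::map<int, Buffer> result;
    for (const auto &entry : sample_ids) {
      result.emplace(entry.first, Buffer());
    }
    return result;
  }
};

}  // namespace nvflare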
+ */ +#include "delegated_plugin.h" +#include "pass_thru_plugin.h" +#include "nvflare_plugin.h" + +namespace nvflare { + +DelegatedPlugin::DelegatedPlugin(std::vector> const &args): + BasePlugin(args) { + + auto name = get_string(args, "name"); + // std::cout << "==== Name is " << name << std::endl; + if (name == "pass-thru") { + plugin_ = new PassThruPlugin(args); + } else if (name == "nvflare") { + plugin_ = new NvflarePlugin(args); + } else { + throw std::invalid_argument{"Unknown plugin name: " + name}; + } +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc new file mode 100644 index 0000000000..99e304ea77 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc @@ -0,0 +1,366 @@ +/** + * Copyright 2014-2024 by XGBoost Contributors + */ +#include +#include +#include +#include "local_plugin.h" +#include "data_set_ids.h" + +namespace nvflare { + +void LocalPlugin::EncryptGPairs(const float *in_gpair, std::size_t n_in, std::uint8_t **out_gpair, std::size_t *n_out) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::EncryptGPairs called with pairs size: " << n_in << std::endl; + } + + if (print_timing_) { + std::cout << "Encrypting " << n_in / 2 << " GH Pairs" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto pairs = std::vector(in_gpair, in_gpair + n_in); + auto double_pairs = std::vector(pairs.cbegin(), pairs.cend()); + auto encrypted_data = EncryptVector(double_pairs); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast(std::chrono::duration_cast(end - start).count()) / 1000.0; + std::cout << "Encryption time: " << secs << " seconds" << std::endl; + } + + // Serialize with DAM so the buffers can be separated after all-gather + DamEncoder encoder(kDataSetGHPairs, true, dam_debug_); + encoder.AddBuffer(encrypted_data); + + std::size_t size; + auto buffer = encoder.Finish(size); + FreeEncryptedData(encrypted_data); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_gpair = buffer_.data(); + *n_out = buffer_.size(); + if (debug_) { + std::cout << "Encrypted GPairs:" << std::endl; + print_buffer(*out_gpair, *n_out); + } + + // Save pairs for future operations. 
This is only called on active site + gh_pairs_ = std::vector(double_pairs); +} + +void LocalPlugin::SyncEncryptedGPairs(const std::uint8_t *in_gpair, std::size_t n_bytes, + const std::uint8_t **out_gpair, std::size_t *out_n_bytes) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedGPairs called with buffer:" << std::endl; + print_buffer(in_gpair, n_bytes); + } + + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; + auto decoder = DamDecoder(const_cast(in_gpair), n_bytes, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "LocalPlugin::SyncEncryptedGPairs called with wrong data" << std::endl; + return; + } + + auto encrypted_buffer = decoder.DecodeBuffer(); + if (debug_) { + std::cout << "Encrypted buffer size: " << encrypted_buffer.buf_size << std::endl; + } + + // The caller may free buffer so a copy is needed + auto pointer = static_cast(encrypted_buffer.buffer); + encrypted_gh_ = std::vector(pointer, pointer + encrypted_buffer.buf_size); + FreeEncryptedData(encrypted_buffer); +} + +void LocalPlugin::ResetHistContext(const std::uint32_t *cutptrs, std::size_t cutptr_len, const std::int32_t *bin_idx, + std::size_t n_idx) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::ResetHistContext called with cutptrs size: " << cutptr_len << " bin_idx size: " + << n_idx << std::endl; + } + + cuts_ = std::vector(cutptrs, cutptrs + cutptr_len); + slots_ = std::vector(bin_idx, bin_idx + n_idx); +} + +void LocalPlugin::BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistHori called with " << len << " entries" << std::endl; + print_buffer(reinterpret_cast(in_histogram), len); + } + + // don't have a local implementation yet, just encoded it and let NVFlare handle it. + DamEncoder encoder(kDataSetHistograms, false, dam_debug_); + auto histograms = std::vector(in_histogram, in_histogram + len); + encoder.AddFloatArray(histograms); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_hist = buffer_.data(); + *out_len = buffer_.size(); + if (debug_) { + std::cout << "Output buffer" << std::endl; + print_buffer(*out_hist, *out_len); + } +} + +void LocalPlugin::SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + print_buffer(buffer, len); + } + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. 
It may contain multiple DAM buffers + std::vector& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast(pointer - buffer) << std::endl; + break; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + std::to_string(decoder.GetDataSetId())}; + } + + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); + + if (debug_) { + std::cout << "Output buffer" << std::endl; + print_buffer(reinterpret_cast(*out_hist), histo_.size() * sizeof(double)); + } +} + +void LocalPlugin::BuildEncryptedHistVert(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVert called with number of nodes: " << len << std::endl; + } + + if (gh_pairs_.empty()) { + BuildEncryptedHistVertPassive(ridx, sizes, nidx, len, out_hist, out_len); + } else { + BuildEncryptedHistVertActive(ridx, sizes, nidx, len, out_hist, out_len); + } + + if (debug_) { + std::cout << "Encrypted histogram output:" << std::endl; + print_buffer(*out_hist, *out_len); + } +} + +void LocalPlugin::BuildEncryptedHistVertActive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVertActive called with " << len << " nodes" << std::endl; + } + + auto total_bin_size = cuts_.back(); + auto histo_size = total_bin_size * 2; + auto total_size = histo_size * len; + + histo_.clear(); + histo_.resize(total_size); + size_t start = 0; + for (std::size_t i = 0; i < len; i++) { + for (std::size_t j = 0; j < sizes[i]; j++) { + auto row_id = ridx[i][j]; + auto num = cuts_.size() - 1; + for (std::size_t f = 0; f < num; f++) { + int slot = slots_[f + num * row_id]; + if ((slot < 0) || (slot >= total_bin_size)) { + continue; + } + auto g = gh_pairs_[row_id * 2]; + auto h = gh_pairs_[row_id * 2 + 1]; + (histo_)[start + slot * 2] += g; + (histo_)[start + slot * 2 + 1] += h; + } + } + start += histo_size; + } + + // Histogram is in clear, can't send to all_gather. 
Just return empty DAM buffer + auto encoder = DamEncoder(kDataSetAggregationResult, true, dam_debug_); + encoder.AddBuffer(Buffer()); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = size; +} + +void LocalPlugin::BuildEncryptedHistVertPassive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVertPassive called with " << len << " nodes" << std::endl; + } + + auto num_slot = cuts_.back(); + auto total_size = num_slot * len; + + auto encrypted_histo = std::vector(total_size); + size_t offset = 0; + for (std::size_t i = 0; i < len; i++) { + auto num = cuts_.size() - 1; + auto row_id_map = std::map>(); + + // Empty slot leaks data so fill everything with empty vectors + for (int slot = 0; slot < num_slot; slot++) { + row_id_map.insert({slot, std::vector()}); + } + + for (std::size_t f = 0; f < num; f++) { + for (std::size_t j = 0; j < sizes[i]; j++) { + auto row_id = ridx[i][j]; + int slot = slots_[f + num * row_id]; + if ((slot < 0) || (slot >= num_slot)) { + continue; + } + auto &row_ids = row_id_map[slot]; + row_ids.push_back(static_cast(row_id)); + } + } + + if (print_timing_) { + std::size_t add_ops = 0; + for (auto &item: row_id_map) { + add_ops += item.second.size(); + } + std::cout << "Aggregating with " << add_ops << " additions" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto encrypted_sum = AddGHPairs(row_id_map); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast(std::chrono::duration_cast(end - start).count()) / 1000.0; + std::cout << "Aggregation time: " << secs << " seconds" << std::endl; + } + + // Convert map back to array + for (int slot = 0; slot < num_slot; slot++) { + auto it = encrypted_sum.find(slot); + if (it != encrypted_sum.end()) { + encrypted_histo[offset + slot] = it->second; + } + } + + offset += num_slot; + } + + auto encoder = DamEncoder(kDataSetAggregationResult, true, dam_debug_); + encoder.AddBufferArray(encrypted_histo); + std::size_t size; + auto buffer = encoder.Finish(size); + for (auto &item: encrypted_histo) { + FreeEncryptedData(item); + } + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = size; +} + +void LocalPlugin::SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistVert called with buffer size: " << len << " nodes" << std::endl; + print_buffer(hist_buffer, len); + } + + auto remaining = len; + auto pointer = hist_buffer; + + *out = nullptr; + *out_len = 0; + if (gh_pairs_.empty()) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistVert Do nothing for passive worker" << std::endl; + } + // Do nothing for passive worker + return; + } + + // The buffer is concatenated by AllGather. 
It may contain multiple DAM buffers + auto first = true; + auto orig_size = histo_.size(); + while (remaining > kPrefixLen) { + DamDecoder decoder(pointer, remaining, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast((pointer - hist_buffer)) << std::endl; + break; + } + auto size = decoder.Size(); + if (first) { + if (histo_.empty()) { + std::cout << "No clear histogram." << std::endl; + return; + } + first = false; + } else { + auto encrypted_buf = decoder.DecodeBufferArray(); + + if (print_timing_) { + std::cout << "Decrypting " << encrypted_buf.size() << " pairs" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto decrypted_histo = DecryptVector(encrypted_buf); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast(std::chrono::duration_cast(end - start).count()) / 1000.0; + std::cout << "Decryption time: " << secs << " seconds" << std::endl; + } + + if (decrypted_histo.size() != orig_size) { + std::cout << "Histo sizes are different: " << decrypted_histo.size() + << " != " << orig_size << std::endl; + } + histo_.insert(histo_.end(), decrypted_histo.cbegin(), decrypted_histo.cend()); + } + remaining -= size; + pointer += size; + } + + if (debug_) { + std::cout << Ident() << " Decrypted result size: " << histo_.size() << std::endl; + } + + // print_buffer(reinterpret_cast(result.data()), result.size()*8); + + *out = histo_.data(); + *out_len = histo_.size(); +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc new file mode 100644 index 0000000000..b062aecfa6 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc @@ -0,0 +1,297 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include // for copy_n, transform +#include // for memcpy +#include // for invalid_argument +#include // for vector + +#include "nvflare_plugin.h" +#include "data_set_ids.h" +#include "dam.h" // for DamEncoder + +namespace nvflare { + +void NvflarePlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, + std::size_t *n_out) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::EncryptGPairs called with pairs size: " << n_in<< std::endl; + } + + auto pairs = std::vector(in_gpair, in_gpair + n_in); + gh_pairs_ = std::vector(pairs.cbegin(), pairs.cend()); + + DamEncoder encoder(kDataSetGHPairs, false, dam_debug_); + encoder.AddFloatArray(gh_pairs_); + std::size_t size; + auto buffer = encoder.Finish(size); + if (!out_gpair) { + throw std::invalid_argument{"Invalid pointer to output gpair."}; + } + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_gpair = buffer_.data(); + *n_out = size; +} + +void NvflarePlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, + std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedGPairs called with buffer size: " << n_bytes << std::endl; + } + + // For NVFlare plugin, nothing needs to be done here + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; +} + +void NvflarePlugin::ResetHistContext(std::uint32_t const *cutptrs, + std::size_t cutptr_len, + std::int32_t const *bin_idx, + std::size_t n_idx) { + if (debug_) { + std::cout << Ident() << " NvFlarePlugin::ResetHistContext called with cutptrs size: " << cutptr_len << " bin_idx size: " + << n_idx<< std::endl; + } + + cut_ptrs_.resize(cutptr_len); + std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); + bin_idx_.resize(n_idx); + std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); +} + +void NvflarePlugin::BuildEncryptedHistVert(std::uint64_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, + std::size_t len, + std::uint8_t** out_hist, + std::size_t* out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::BuildEncryptedHistVert called with len: " << len << std::endl; + } + + std::int64_t data_set_id; + if (!feature_sent_) { + data_set_id = kDataSetAggregationWithFeatures; + feature_sent_ = true; + } else { + data_set_id = kDataSetAggregation; + } + + DamEncoder encoder(data_set_id, false, dam_debug_); + + // Add cuts pointers + std::vector cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); + encoder.AddIntArray(cuts_vec); + + auto num_features = cut_ptrs_.size() - 1; + auto num_samples = bin_idx_.size() / num_features; + if (debug_) { + std::cout << "Samples: " << num_samples << " Features: " << num_features << std::endl; + } + + std::vector bins; + if (data_set_id == kDataSetAggregationWithFeatures) { + if (features_.empty()) { // when is it not empty? + for (int64_t f = 0; f < num_features; f++) { + auto slot = bin_idx_[f]; + if (slot >= 0) { + // what happens if it's missing? 
+ features_.push_back(f); + } + } + } + encoder.AddIntArray(features_); + + for (int i = 0; i < num_samples; i++) { + for (auto f : features_) { + auto index = f + i * num_features; + if (index > bin_idx_.size()) { + throw std::out_of_range{"Index is out of range: " + + std::to_string(index)}; + } + auto slot = bin_idx_[index]; + bins.push_back(slot); + } + } + encoder.AddIntArray(bins); + } + + // Add nodes to build + std::vector node_vec(len); + for (std::size_t i = 0; i < len; i++) { + node_vec[i] = nidx[i]; + } + encoder.AddIntArray(node_vec); + + // For each node, get the row_id/slot pair + auto row_ids = std::vector>(len); + for (std::size_t i = 0; i < len; ++i) { + auto& rows = row_ids[i]; + rows.resize(sizes[i]); + for (std::size_t j = 0; j < sizes[i]; j++) { + rows[j] = static_cast(ridx[i][j]); + } + encoder.AddIntArray(rows); + } + + std::size_t n{0}; + auto buffer = encoder.Finish(n); + if (debug_) { + std::cout << "Finished size: " << n << std::endl; + } + + // XGBoost doesn't allow the change of allgatherV sizes. Make sure it's big + // enough to carry histograms + auto max_slot = cut_ptrs_.back(); + auto histo_size = 2 * max_slot * sizeof(double) * len + 1024*1024; // 1M is DAM overhead + auto buf_size = histo_size > n ? histo_size : n; + + // Copy to an array so the buffer can be freed, should change encoder to return vector + buffer_.resize(buf_size); + std::copy_n(buffer, n, buffer_.begin()); + free(buffer); + + *out_hist = buffer_.data(); + *out_len = buffer_.size(); +} + +void NvflarePlugin::SyncEncryptedHistVert(std::uint8_t *buffer, + std::size_t buf_size, + double **out, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedHistVert called with buffer size: " << buf_size << std::endl; + } + + auto remaining = buf_size; + char *pointer = reinterpret_cast(buffer); + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector &result = histo_; + result.clear(); + auto max_slot = cut_ptrs_.back(); + auto array_size = 2 * max_slot * sizeof(double); + + // A new histogram array? + auto slots = static_cast(malloc(array_size)); + while (remaining > kPrefixLen) { + DamDecoder decoder(reinterpret_cast(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast((pointer - reinterpret_cast(buffer))) << std::endl; + break; + } + auto size = decoder.Size(); + auto node_list = decoder.DecodeIntArray(); + if (debug_) { + std::cout << "Number of nodes: " << node_list.size() << " Histo size: " << 2*max_slot << std::endl; + } + for ([[maybe_unused]] auto node : node_list) { + std::memset(slots, 0, array_size); + auto feature_list = decoder.DecodeIntArray(); + // Convert per-feature histo to a flat one + for (auto f : feature_list) { + auto base = cut_ptrs_[f]; // cut pointer for the current feature + auto bins = decoder.DecodeFloatArray(); + auto n = bins.size() / 2; + for (int i = 0; i < n; i++) { + auto index = base + i; + // [Q] Build local histogram? Why does it need to be built here? 
+ slots[2 * index] += bins[2 * i]; + slots[2 * index + 1] += bins[2 * i + 1]; + } + } + result.insert(result.end(), slots, slots + 2 * max_slot); + } + remaining -= size; + pointer += size; + } + free(slots); + + // result is a reference to a histo_ + *out_len = result.size(); + *out = result.data(); + if (debug_) { + std::cout << "Total histogram size: " << *out_len << std::endl; + } +} + +void NvflarePlugin::BuildEncryptedHistHori(double const *in_histogram, + std::size_t len, + std::uint8_t **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::BuildEncryptedHistHori called with histo size: " << len << std::endl; + } + + DamEncoder encoder(kDataSetHistograms, false, dam_debug_); + std::vector copy(in_histogram, in_histogram + len); + encoder.AddFloatArray(copy); + + std::size_t size{0}; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_hist = this->buffer_.data(); + *out_len = this->buffer_.size(); +} + +void NvflarePlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, + std::size_t len, + double **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + } + + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast(pointer - buffer) << std::endl; + break; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + std::to_string(decoder.GetDataSetId())}; + } + + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc new file mode 100644 index 0000000000..4a29d0ed2b --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include + +#include "pass_thru_plugin.h" +#include "data_set_ids.h" + +namespace nvflare { + +void PassThruPlugin::BuildEncryptedHistHori(const double *in_histogram, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " PassThruPlugin::BuildEncryptedHistHori called with " << len << " entries" << std::endl; + } + + DamEncoder encoder(kDataSetHistogramResult, true, dam_debug_); + auto array = std::vector(in_histogram, in_histogram + len); + encoder.AddFloatArray(array); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = buffer_.size(); +} + +void PassThruPlugin::SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, + double **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " PassThruPlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + } + + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast(pointer), remaining, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast(pointer - buffer) << std::endl; + break; + } + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); +} + +Buffer PassThruPlugin::EncryptVector(const std::vector& cleartext) { + if (debug_ && cleartext.size() > 2) { + std::cout << "PassThruPlugin::EncryptVector called with cleartext size: " << cleartext.size() << std::endl; + } + + size_t size = cleartext.size() * sizeof(double); + auto buf = static_cast(malloc(size)); + std::copy_n(reinterpret_cast(cleartext.data()), size, buf); + + return {buf, size, true}; +} + +std::vector PassThruPlugin::DecryptVector(const std::vector& ciphertext) { + if (debug_) { + std::cout << "PassThruPlugin::DecryptVector with ciphertext size: " << ciphertext.size() << std::endl; + } + + std::vector result; + + for (auto const &v : ciphertext) { + size_t n = v.buf_size/sizeof(double); + auto p = static_cast(v.buffer); + for (int i = 0; i < n; i++) { + result.push_back(p[i]); + } + } + + return result; +} + +std::map PassThruPlugin::AddGHPairs(const std::map>& sample_ids) { + if (debug_) { + std::cout << "PassThruPlugin::AddGHPairs called with " << sample_ids.size() << " slots" << std::endl; + } + + // Can't do this in real plugin. It needs to be broken into encrypted parts + auto gh_pairs = DecryptVector(std::vector{Buffer(encrypted_gh_.data(), encrypted_gh_.size())}); + + auto result = std::map(); + for (auto const &entry : sample_ids) { + auto rows = entry.second; + double g = 0.0; + double h = 0.0; + + for (auto row : rows) { + g += gh_pairs[2 * row]; + h += gh_pairs[2 * row + 1]; + } + // In real plugin, the sum should be still in encrypted state. 
No need to do this step + auto encrypted_sum = EncryptVector(std::vector{g, h}); + // print_buffer(reinterpret_cast(encrypted_sum.buffer), encrypted_sum.buf_size); + result.insert({entry.first, encrypted_sum}); + } + + return result; +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc b/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc new file mode 100644 index 0000000000..4c1d43a6f8 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc @@ -0,0 +1,184 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include // for shared_ptr +#include // for invalid_argument +#include // for string_view +#include // for vector +#include // for transform + +#include "delegated_plugin.h" + +// Opaque pointer type for the C API. +typedef void *FederatedPluginHandle; // NOLINT + +namespace nvflare { +namespace { +// The opaque type for the C handle. +using CHandleT = std::shared_ptr *; +// Actual representation used in C++ code base. +using HandleT = std::remove_pointer_t; + +std::string &GlobalErrorMsg() { + static thread_local std::string msg; + return msg; +} + +// Perform handle handling for C API functions. 
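+// A null handle returns 1 immediately. Exceptions thrown by the wrapped callable are caught,
+// their message is stored in the thread-local GlobalErrorMsg(), and 1 is returned, so callers
+// can retrieve the error text through FederatedPluginErrorMsg().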
+template auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { + auto pptr = static_cast(handle); + if (!pptr) { + return 1; + } + + try { + if constexpr (std::is_void_v>) { + fn(*pptr); + return 0; + } else { + return fn(*pptr); + } + } catch (std::exception const &e) { + GlobalErrorMsg() = e.what(); + return 1; + } +} +} // namespace +} // namespace nvflare + +#if defined(_MSC_VER) || defined(_WIN32) +#define NVF_C __declspec(dllexport) +#else +#define NVF_C __attribute__((visibility("default"))) +#endif // defined(_MSC_VER) || defined(_WIN32) + +extern "C" { +NVF_C char const *FederatedPluginErrorMsg() { + return nvflare::GlobalErrorMsg().c_str(); +} + +FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { + // std::cout << "==== FedreatedPluginCreate called with argc=" << argc << std::endl; + using namespace nvflare; + try { + auto pptr = new std::shared_ptr; + std::vector> args; + std::transform( + argv, argv + argc, std::back_inserter(args), [](char const *carg) { + // Split a key value pair in contructor argument: `key=value` + std::string_view arg{carg}; + auto idx = arg.find('='); + if (idx == std::string_view::npos) { + // `=` not found + throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; + } + auto key = arg.substr(0, idx); + auto value = arg.substr(idx + 1); + return std::make_pair(key, value); + }); + *pptr = std::make_shared(args); + // std::cout << "==== Plugin created: " << pptr << std::endl; + return pptr; + } catch (std::exception const &e) { + // std::cout << "==== Create exception " << e.what() << std::endl; + GlobalErrorMsg() = e.what(); + return nullptr; + } +} + +int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { + using namespace nvflare; + auto pptr = static_cast(handle); + if (!pptr) { + return 1; + } + + delete pptr; + + return 0; +} + +int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, + float const *in_gpair, size_t n_in, + uint8_t **out_gpair, size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + return 0; + }); +} + +int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, + uint8_t const *in_gpair, + size_t n_bytes, + uint8_t const **out_gpair, + size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); + }); +} + +int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, + uint32_t const *cutptrs, + size_t cutptr_len, + int32_t const *bin_idx, + size_t n_idx) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + }); +} + +int NVF_C FederatedPluginBuildEncryptedHistVert( + FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, + int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEncryptedHistVert(FederatedPluginHandle handle, + uint8_t *in_hist, size_t len, + double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C 
FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, + double const *in_hist, + size_t len, uint8_t **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEncryptedHistHori(FederatedPluginHandle handle, + uint8_t const *in_hist, + size_t len, double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); + return 0; + }); +} +} // extern "C" diff --git a/integration/xgboost/encryption_plugins/src/plugins/util.cc b/integration/xgboost/encryption_plugins/src/plugins/util.cc new file mode 100644 index 0000000000..a0cbd922d4 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/util.cc @@ -0,0 +1,99 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "util.h" + + +constexpr double kScaleFactor = 1000000.0; + +std::vector> distribute_work(size_t num_jobs, size_t const num_workers) { + std::vector> result; + auto num = num_jobs / num_workers; + auto remainder = num_jobs % num_workers; + int start = 0; + for (int i = 0; i < num_workers; i++) { + auto stop = static_cast((start + num - 1)); + if (i < remainder) { + // If jobs cannot be evenly distributed, first few workers take an extra one + stop += 1; + } + + if (start <= stop) { + result.emplace_back(start, stop); + } + start = stop + 1; + } + + // Verify all jobs are distributed + int sum = 0; + for (auto &item: result) { + sum += item.second - item.first + 1; + } + + if (sum != num_jobs) { + std::cout << "Distribution error" << std::endl; + } + + return result; +} + +uint32_t to_int(double d) { + auto int_val = static_cast(d * kScaleFactor); + return static_cast(int_val); +} + +double to_double(uint32_t i) { + auto int_val = static_cast(i); + return static_cast(int_val / kScaleFactor); +} + +std::string get_string(std::vector> const &args, + std::string_view const &key, std::string_view const default_value) { + + auto it = find_if( + args.begin(), args.end(), + [key](const auto &p) { return p.first == key; }); + + if (it != args.end()) { + return std::string{it->second}; + } + + return std::string{default_value}; +} + +bool get_bool(std::vector> const &args, + const std::string &key, bool default_value) { + std::string value = get_string(args, key, ""); + if (value.empty()) { + return default_value; + } + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { return std::tolower(c); }); + auto true_values = std::set < std::string_view > {"true", "yes", "y", "on", "1"}; + return true_values.count(value) > 0; +} + +int get_int(std::vector> const &args, + const std::string &key, int default_value) { + + auto value = get_string(args, key, ""); + if (value.empty()) { + return 
default_value; + } + + return stoi(value, nullptr); +} diff --git a/integration/xgboost/encryption_plugins/tests/CMakeLists.txt b/integration/xgboost/encryption_plugins/tests/CMakeLists.txt new file mode 100644 index 0000000000..04580bdd59 --- /dev/null +++ b/integration/xgboost/encryption_plugins/tests/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE TEST_SOURCES "*.cc") + +target_sources(xgb_nvflare_test PRIVATE ${TEST_SOURCES}) + +target_include_directories(xgb_nvflare_test + PRIVATE + ${GTEST_INCLUDE_DIRS} + ${xgb_nvflare_SOURCE_DIR/tests} + ${xgb_nvflare_SOURCE_DIR}/src) + +message("Include Dir: ${GTEST_INCLUDE_DIRS}") +target_link_libraries(xgb_nvflare_test + PRIVATE + ${GTEST_LIBRARIES}) diff --git a/integration/xgboost/processor/tests/test_dam.cc b/integration/xgboost/encryption_plugins/tests/test_dam.cc similarity index 65% rename from integration/xgboost/processor/tests/test_dam.cc rename to integration/xgboost/encryption_plugins/tests/test_dam.cc index 5573d5440d..345978b110 100644 --- a/integration/xgboost/processor/tests/test_dam.cc +++ b/integration/xgboost/encryption_plugins/tests/test_dam.cc @@ -19,20 +19,45 @@ TEST(DamTest, TestEncodeDecode) { double float_array[] = {1.1, 1.2, 1.3, 1.4}; int64_t int_array[] = {123, 456, 789}; + char buf1[] = "short"; + char buf2[] = "very long"; + DamEncoder encoder(123); + auto b1 = Buffer(buf1, strlen(buf1)); + auto b2 = Buffer(buf2, strlen(buf2)); + encoder.AddBuffer(b1); + encoder.AddBuffer(b2); + + std::vector b{b1, b2}; + encoder.AddBufferArray(b); + auto f = std::vector(float_array, float_array + 4); encoder.AddFloatArray(f); + auto i = std::vector(int_array, int_array + 3); encoder.AddIntArray(i); + size_t size; auto buf = encoder.Finish(size); std::cout << "Encoded size is " << size << std::endl; - DamDecoder decoder(buf.data(), size); + // Decoding test + DamDecoder decoder(buf, size); EXPECT_EQ(decoder.IsValid(), true); EXPECT_EQ(decoder.GetDataSetId(), 123); + auto new_buf1 = decoder.DecodeBuffer(); + EXPECT_EQ(0, memcmp(new_buf1.buffer, buf1, new_buf1.buf_size)); + + auto new_buf2 = decoder.DecodeBuffer(); + EXPECT_EQ(0, memcmp(new_buf2.buffer, buf2, new_buf2.buf_size)); + + auto buf_vec = decoder.DecodeBufferArray(); + EXPECT_EQ(2, buf_vec.size()); + EXPECT_EQ(0, memcmp(buf_vec[0].buffer, buf1, buf_vec[0].buf_size)); + EXPECT_EQ(0, memcmp(buf_vec[1].buffer, buf2, buf_vec[1].buf_size)); + auto float_vec = decoder.DecodeFloatArray(); EXPECT_EQ(0, memcmp(float_vec.data(), float_array, float_vec.size()*8)); diff --git a/integration/xgboost/processor/tests/test_main.cc b/integration/xgboost/encryption_plugins/tests/test_main.cc similarity index 100% rename from integration/xgboost/processor/tests/test_main.cc rename to integration/xgboost/encryption_plugins/tests/test_main.cc diff --git a/integration/xgboost/processor/tests/test_tenseal.py b/integration/xgboost/encryption_plugins/tests/test_tenseal.py similarity index 100% rename from integration/xgboost/processor/tests/test_tenseal.py rename to integration/xgboost/encryption_plugins/tests/test_tenseal.py diff --git a/integration/xgboost/processor/CMakeLists.txt b/integration/xgboost/processor/CMakeLists.txt deleted file mode 100644 index 056fd365e2..0000000000 --- a/integration/xgboost/processor/CMakeLists.txt +++ /dev/null @@ -1,46 +0,0 @@ -cmake_minimum_required(VERSION 3.19) -project(proc_nvflare LANGUAGES CXX C VERSION 1.0) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) - -option(GOOGLE_TEST "Build google tests" OFF) - -file(GLOB_RECURSE LIB_SRC "src/*.cc") - 
-add_library(proc_nvflare SHARED ${LIB_SRC}) -set_target_properties(proc_nvflare PROPERTIES - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON - ENABLE_EXPORTS ON -) -target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include) - -if (APPLE) - add_link_options("LINKER:-object_path_lto,$_lto.o") - add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache") -endif () - -#-- Unit Tests -if(GOOGLE_TEST) - find_package(GTest REQUIRED) - enable_testing() - add_executable(proc_test) - target_link_libraries(proc_test PRIVATE proc_nvflare) - - - target_include_directories(proc_test PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include - ${XGB_SRC}/src - ${XGB_SRC}/rabit/include - ${XGB_SRC}/include - ${XGB_SRC}/dmlc-core/include - ${XGB_SRC}/tests) - - add_subdirectory(${proc_nvflare_SOURCE_DIR}/tests) - - add_test( - NAME TestProcessor - COMMAND proc_test - WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR}) - -endif() diff --git a/integration/xgboost/processor/README.md b/integration/xgboost/processor/README.md deleted file mode 100644 index e879081b84..0000000000 --- a/integration/xgboost/processor/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Build Instruction - -``` sh -cd NVFlare/integration/xgboost/processor -mkdir build -cd build -cmake .. -make -``` - -See [tests](./tests) for simple examples. \ No newline at end of file diff --git a/integration/xgboost/processor/src/dam/dam.cc b/integration/xgboost/processor/src/dam/dam.cc deleted file mode 100644 index 10625ab9b5..0000000000 --- a/integration/xgboost/processor/src/dam/dam.cc +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include "dam.h" - -void print_buffer(uint8_t *buffer, int size) { - for (int i = 0; i < size; i++) { - auto c = buffer[i]; - std::cout << std::hex << (int) c << " "; - } - std::cout << std::endl << std::dec; -} - -// DamEncoder ====== -void DamEncoder::AddFloatArray(const std::vector &value) { - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size() * 8; - uint8_t *buffer = static_cast(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); -} - -void DamEncoder::AddIntArray(const std::vector &value) { - std::cout << "AddIntArray called, size: " << value.size() << std::endl; - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size()*8; - std::cout << "Allocating " << buf_size << " bytes" << std::endl; - uint8_t *buffer = static_cast(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - // print_buffer(buffer, buf_size); - entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size())); -} - -std::vector DamEncoder::Finish(size_t &size) { - encoded = true; - - size = calculate_size(); - std::vector buf(size); - auto pointer = buf.data(); - memcpy(pointer, kSignature, strlen(kSignature)); - memcpy(pointer + 8, &size, 8); - memcpy(pointer + 16, &data_set_id, 8); - - pointer += kPrefixLen; - for (auto entry : *entries) { - memcpy(pointer, &entry->data_type, 8); - pointer += 8; - memcpy(pointer, &entry->size, 8); - pointer += 8; - int len = 8*entry->size; - memcpy(pointer, entry->pointer, len); - free(entry->pointer); - pointer += len; - // print_buffer(entry->pointer, entry->size*8); - } - - if ((pointer - buf.data()) != size) { - throw std::runtime_error{"Invalid encoded size: " + - std::to_string(pointer - buf.data())}; - } - - return buf; -} - -std::size_t DamEncoder::calculate_size() { - auto size = kPrefixLen; - - for (auto entry : *entries) { - size += 16; // The Type and Len - size += entry->size * 8; // All supported data types are 8 bytes - } - - return size; -} - - -// DamDecoder ====== - -DamDecoder::DamDecoder(std::uint8_t const *buffer, std::size_t size) { - this->buffer = buffer; - this->buf_size = size; - this->pos = buffer + kPrefixLen; - if (size >= kPrefixLen) { - memcpy(&len, buffer + 8, 8); - memcpy(&data_set_id, buffer + 16, 8); - } else { - len = 0; - data_set_id = 0; - } -} - -bool DamDecoder::IsValid() { - return buf_size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0; -} - -std::vector DamDecoder::DecodeIntArray() { - auto type = *reinterpret_cast(pos); - if (type != kDataTypeIntArray) { - std::cout << "Data type " << type << " doesn't match Int Array" - << std::endl; - return std::vector(); - } - pos += 8; - - auto len = *reinterpret_cast(pos); - pos += 8; - auto ptr = reinterpret_cast(pos); - pos += 8 * len; - return std::vector(ptr, ptr + len); -} - -std::vector DamDecoder::DecodeFloatArray() { - auto type = *reinterpret_cast(pos); - if (type != kDataTypeFloatArray) { - std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; - return std::vector(); - } - pos += 8; - - auto len = *reinterpret_cast(pos); - pos += 8; - - auto ptr = reinterpret_cast(pos); - pos += 8*len; - return std::vector(ptr, ptr + len); -} diff --git a/integration/xgboost/processor/src/include/dam.h b/integration/xgboost/processor/src/include/dam.h deleted file mode 100644 index 7afdf983af..0000000000 --- 
a/integration/xgboost/processor/src/include/dam.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include // for int64_t -#include // for size_t - -const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 -const int kPrefixLen = 24; - -const int kDataTypeInt = 1; -const int kDataTypeFloat = 2; -const int kDataTypeString = 3; -const int kDataTypeIntArray = 257; -const int kDataTypeFloatArray = 258; - -const int kDataTypeMap = 1025; - -class Entry { - public: - int64_t data_type; - uint8_t * pointer; - int64_t size; - - Entry(int64_t data_type, uint8_t *pointer, int64_t size) { - this->data_type = data_type; - this->pointer = pointer; - this->size = size; - } -}; - -class DamEncoder { - private: - bool encoded = false; - int64_t data_set_id; - std::vector *entries = new std::vector(); - - public: - explicit DamEncoder(int64_t data_set_id) { - this->data_set_id = data_set_id; - } - - void AddIntArray(const std::vector &value); - - void AddFloatArray(const std::vector &value); - - std::vector Finish(size_t &size); - - private: - std::size_t calculate_size(); -}; - -class DamDecoder { - private: - std::uint8_t const *buffer = nullptr; - std::size_t buf_size = 0; - std::uint8_t const *pos = nullptr; - std::size_t remaining = 0; - int64_t data_set_id = 0; - int64_t len = 0; - - public: - explicit DamDecoder(std::uint8_t const *buffer, std::size_t size); - - size_t Size() { - return len; - } - - int64_t GetDataSetId() { - return data_set_id; - } - - bool IsValid(); - - std::vector DecodeIntArray(); - - std::vector DecodeFloatArray(); -}; - -void print_buffer(uint8_t *buffer, int size); diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc deleted file mode 100644 index 3e742b14ef..0000000000 --- a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc +++ /dev/null @@ -1,378 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "nvflare_processor.h" - -#include "dam.h" // for DamEncoder -#include -#include // for copy_n, transform -#include // for memcpy -#include // for shared_ptr -#include // for invalid_argument -#include // for string_view -#include // for vector - -namespace nvflare { -namespace { -// The opaque type for the C handle. -using CHandleT = std::shared_ptr *; -// Actual representation used in C++ code base. -using HandleT = std::remove_pointer_t; - -std::string &GlobalErrorMsg() { - static thread_local std::string msg; - return msg; -} - -// Perform handle handling for C API functions. -template auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { - auto pptr = static_cast(handle); - if (!pptr) { - return 1; - } - - try { - if constexpr (std::is_void_v>) { - fn(*pptr); - return 0; - } else { - return fn(*pptr); - } - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return 1; - } -} -} // namespace - -TensealPlugin::TensealPlugin( - std::vector> const &args) { - if (!args.empty()) { - throw std::invalid_argument{"Invaid arguments for the tenseal plugin."}; - } -} - -void TensealPlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, - std::uint8_t **out_gpair, - std::size_t *n_out) { - std::vector pairs(n_in); - std::copy_n(in_gpair, n_in, pairs.begin()); - DamEncoder encoder(kDataSetHGPairs); - encoder.AddFloatArray(pairs); - encrypted_gpairs_ = encoder.Finish(*n_out); - if (!out_gpair) { - throw std::invalid_argument{"Invalid pointer to output gpair."}; - } - *out_gpair = encrypted_gpairs_.data(); - *n_out = encrypted_gpairs_.size(); -} - -void TensealPlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, - std::size_t n_bytes, - std::uint8_t const **out_gpair, - std::size_t *out_n_bytes) { - *out_n_bytes = n_bytes; - *out_gpair = in_gpair; -} - -void TensealPlugin::ResetHistContext(std::uint32_t const *cutptrs, - std::size_t cutptr_len, - std::int32_t const *bin_idx, - std::size_t n_idx) { - // fixme: this doesn't have to be copied multiple times. - this->cut_ptrs_.resize(cutptr_len); - std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); - this->bin_idx_.resize(n_idx); - std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); -} - -void TensealPlugin::BuildEncryptedHistVert(std::size_t const **ridx, - std::size_t const *sizes, - std::int32_t const *nidx, - std::size_t len, - std::uint8_t** out_hist, - std::size_t* out_len) { - std::int64_t data_set_id; - if (!feature_sent_) { - data_set_id = kDataSetAggregationWithFeatures; - feature_sent_ = true; - } else { - data_set_id = kDataSetAggregation; - } - - DamEncoder encoder(data_set_id); - - // Add cuts pointers - std::vector cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); - encoder.AddIntArray(cuts_vec); - - auto num_features = cut_ptrs_.size() - 1; - auto num_samples = bin_idx_.size() / num_features; - - if (data_set_id == kDataSetAggregationWithFeatures) { - if (features_.empty()) { // when is it not empty? - for (std::size_t f = 0; f < num_features; f++) { - auto slot = bin_idx_[f]; - if (slot >= 0) { - // what happens if it's missing? 
- features_.push_back(f); - } - } - } - encoder.AddIntArray(features_); - - std::vector bins; - for (int i = 0; i < num_samples; i++) { - for (auto f : features_) { - auto index = f + i * num_features; - if (index > bin_idx_.size()) { - throw std::out_of_range{"Index is out of range: " + - std::to_string(index)}; - } - auto slot = bin_idx_[index]; - bins.push_back(slot); - } - } - encoder.AddIntArray(bins); - } - - // Add nodes to build - std::vector node_vec(len); - std::copy_n(nidx, len, node_vec.begin()); - encoder.AddIntArray(node_vec); - - // For each node, get the row_id/slot pair - for (std::size_t i = 0; i < len; ++i) { - std::vector rows(sizes[i]); - std::copy_n(ridx[i], sizes[i], rows.begin()); - encoder.AddIntArray(rows); - } - - std::size_t n{0}; - encrypted_hist_ = encoder.Finish(n); - - *out_hist = encrypted_hist_.data(); - *out_len = encrypted_hist_.size(); -} - -void TensealPlugin::SyncEncryptedHistVert(std::uint8_t *buffer, - std::size_t buf_size, double **out, - std::size_t *out_len) { - auto remaining = buf_size; - char *pointer = reinterpret_cast(buffer); - - // The buffer is concatenated by AllGather. It may contain multiple DAM - // buffers - std::vector &result = hist_; - result.clear(); - auto max_slot = cut_ptrs_.back(); - auto array_size = 2 * max_slot * sizeof(double); - // A new histogram array? - double *slots = static_cast(malloc(array_size)); - while (remaining > kPrefixLen) { - DamDecoder decoder(reinterpret_cast(pointer), remaining); - if (!decoder.IsValid()) { - std::cout << "Not DAM encoded buffer ignored at offset: " - << static_cast( - (pointer - reinterpret_cast(buffer))) - << std::endl; - break; - } - auto size = decoder.Size(); - auto node_list = decoder.DecodeIntArray(); - for (auto node : node_list) { - std::memset(slots, 0, array_size); - auto feature_list = decoder.DecodeIntArray(); - // Convert per-feature histo to a flat one - for (auto f : feature_list) { - auto base = cut_ptrs_[f]; // cut pointer for the current feature - auto bins = decoder.DecodeFloatArray(); - auto n = bins.size() / 2; - for (int i = 0; i < n; i++) { - auto index = base + i; - // [Q] Build local histogram? Why does it need to be built here? 
- slots[2 * index] += bins[2 * i]; - slots[2 * index + 1] += bins[2 * i + 1]; - } - } - result.insert(result.end(), slots, slots + 2 * max_slot); - } - remaining -= size; - pointer += size; - } - free(slots); - - *out_len = result.size(); - *out = result.data(); -} - -void TensealPlugin::BuildEncryptedHistHori(double const *in_histogram, - std::size_t len, - std::uint8_t **out_hist, - std::size_t *out_len) { - DamEncoder encoder(kDataSetHistograms); - std::vector copy(in_histogram, in_histogram + len); - encoder.AddFloatArray(copy); - - std::size_t size{0}; - this->encrypted_hist_ = encoder.Finish(size); - - *out_hist = this->encrypted_hist_.data(); - *out_len = this->encrypted_hist_.size(); -} - -void TensealPlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, - std::size_t len, double **out_hist, - std::size_t *out_len) { - DamDecoder decoder(reinterpret_cast(buffer), len); - if (!decoder.IsValid()) { - std::cout << "Not DAM encoded buffer, ignored" << std::endl; - } - - if (decoder.GetDataSetId() != kDataSetHistogramResult) { - throw std::runtime_error{"Invalid dataset: " + - std::to_string(decoder.GetDataSetId())}; - } - this->hist_ = decoder.DecodeFloatArray(); - *out_hist = this->hist_.data(); - *out_len = this->hist_.size(); -} -} // namespace nvflare - -#if defined(_MSC_VER) || defined(_WIN32) -#define NVF_C __declspec(dllexport) -#else -#define NVF_C __attribute__((visibility("default"))) -#endif // defined(_MSC_VER) || defined(_WIN32) - -extern "C" { -NVF_C char const *FederatedPluginErrorMsg() { - return nvflare::GlobalErrorMsg().c_str(); -} - -FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { - using namespace nvflare; - try { - CHandleT pptr = new std::shared_ptr; - std::vector> args; - std::transform( - argv, argv + argc, std::back_inserter(args), [](char const *carg) { - // Split a key value pair in contructor argument: `key=value` - std::string_view arg{carg}; - auto idx = arg.find('='); - if (idx == std::string_view::npos) { - // `=` not found - throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; - } - auto key = arg.substr(0, idx); - auto value = arg.substr(idx + 1); - return std::make_pair(key, value); - }); - *pptr = std::make_shared(args); - return pptr; - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return nullptr; - } -} - -int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { - using namespace nvflare; - auto pptr = static_cast(handle); - if (!pptr) { - return 1; - } - - try { - delete pptr; - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return 1; - } - return 0; -} - -int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, - float const *in_gpair, size_t n_in, - uint8_t **out_gpair, size_t *n_out) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); - return 0; - }); -} - -int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, - uint8_t const *in_gpair, - size_t n_bytes, - uint8_t const **out_gpair, - size_t *n_out) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); - }); -} - -int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, - uint32_t const *cutptrs, - size_t cutptr_len, - int32_t const *bin_idx, - size_t n_idx) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - 
plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); - }); -} - -int NVF_C FederatedPluginBuildEncryptedHistVert( - FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, - int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginSyncEnrcyptedHistVert(FederatedPluginHandle handle, - uint8_t *in_hist, size_t len, - double **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, - double const *in_hist, - size_t len, uint8_t **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginSyncEnrcyptedHistHori(FederatedPluginHandle handle, - uint8_t const *in_hist, - size_t len, double **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); - return 0; - }); -} -} // extern "C" diff --git a/integration/xgboost/processor/tests/CMakeLists.txt b/integration/xgboost/processor/tests/CMakeLists.txt deleted file mode 100644 index 893d8738dc..0000000000 --- a/integration/xgboost/processor/tests/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -file(GLOB_RECURSE TEST_SOURCES "*.cc") - -target_sources(proc_test PRIVATE ${TEST_SOURCES}) - -target_include_directories(proc_test - PRIVATE - ${GTEST_INCLUDE_DIRS} - ${proc_nvflare_SOURCE_DIR/tests} - ${proc_nvflare_SOURCE_DIR}/src) - -message("Include Dir: ${GTEST_INCLUDE_DIRS}") -target_link_libraries(proc_test - PRIVATE - ${GTEST_LIBRARIES}) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py index b559306440..b32535ac15 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py @@ -109,6 +109,8 @@ class Constant: HEADER_KEY_HORIZONTAL = "xgb.horizontal" HEADER_KEY_ORIGINAL_BUF_SIZE = "xgb.original_buf_size" HEADER_KEY_IN_AGGR = "xgb.in_aggr" + HEADER_KEY_WORLD_SIZE = "xgb.world_size" + HEADER_KEY_SIZE_DICT = "xgb.size_dict" DUMMY_BUFFER_SIZE = 4 @@ -122,8 +124,6 @@ class Constant: class SplitMode: ROW = 0 COL = 1 - COL_SECURE = 2 - ROW_SECURE = 3 # Mapping of text training mode to split mode @@ -132,10 +132,10 @@ class SplitMode: "horizontal": SplitMode.ROW, "v": SplitMode.COL, "vertical": SplitMode.COL, - "hs": SplitMode.ROW_SECURE, - "horizontal_secure": SplitMode.ROW_SECURE, - "vs": SplitMode.COL_SECURE, - "vertical_secure": SplitMode.COL_SECURE, + "hs": SplitMode.ROW, + "horizontal_secure": SplitMode.ROW, + "vs": SplitMode.COL, + "vertical_secure": SplitMode.COL, } SECURE_TRAINING_MODES = {"hs", "horizontal_secure", "vs", "vertical_secure"} diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi index 7ad47596df..7dc3e6dde1 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi @@ -6,7 +6,7 @@ from typing import ClassVar as _ClassVar, Optional as 
_Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = () HALF: _ClassVar[DataType] FLOAT: _ClassVar[DataType] DOUBLE: _ClassVar[DataType] @@ -21,7 +21,7 @@ class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): UINT64: _ClassVar[DataType] class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = () MAX: _ClassVar[ReduceOperation] MIN: _ClassVar[ReduceOperation] SUM: _ClassVar[ReduceOperation] @@ -48,7 +48,7 @@ BITWISE_OR: ReduceOperation BITWISE_XOR: ReduceOperation class AllgatherRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer"] + __slots__ = ("sequence_number", "rank", "send_buffer") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -58,13 +58,13 @@ class AllgatherRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer"] + __slots__ = ("sequence_number", "rank", "send_buffer") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -74,13 +74,13 @@ class AllgatherVRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllreduceRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer", "data_type", "reduce_operation"] + __slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -94,13 +94,13 @@ class AllreduceRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., data_type: _Optional[_Union[DataType, str]] = ..., reduce_operation: _Optional[_Union[ReduceOperation, str]] = ...) -> None: ... class AllreduceReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class BroadcastRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer", "root"] + __slots__ = ("sequence_number", "rank", "send_buffer", "root") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -112,7 +112,7 @@ class BroadcastRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., root: _Optional[int] = ...) -> None: ... 
class BroadcastReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py index 45eee5c8dd..549d0e4ffc 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: federated.proto +# Protobuf Python Version: 4.25.1 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as federated__pb2 - class FederatedStub(object): """Missing associated documentation comment in .proto file.""" diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py index 1b98829711..0d8e8bec1d 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py @@ -30,7 +30,9 @@ from nvflare.fuel.utils.obj_utils import get_logger from nvflare.utils.cli_utils import get_package_root -LOADER_PARAMS_LIBRARY_PATH = "LIBRARY_PATH" +PLUGIN_PARAM_KEY = "federated_plugin" +PLUGIN_KEY_NAME = "name" +PLUGIN_KEY_PATH = "path" class XGBClientRunner(AppRunner, FLComponent): @@ -135,7 +137,7 @@ def run(self, ctx: dict): self.logger.info(f"server address is {self._server_addr}") communicator_env = { - "xgboost_communicator": "federated", + "dmlc_communicator": "federated", "federated_server_address": f"{self._server_addr}", "federated_world_size": self._world_size, "federated_rank": self._rank, @@ -145,38 +147,35 @@ def run(self, ctx: dict): self.logger.info("XGBoost non-secure training") else: xgb_plugin_name = ConfigService.get_str_var( - name="xgb_plugin_name", conf=SystemConfigs.RESOURCES_CONF, default="nvflare" + name="xgb_plugin_name", conf=SystemConfigs.RESOURCES_CONF, default=None ) - - xgb_loader_params = ConfigService.get_dict_var( - name="xgb_loader_params", conf=SystemConfigs.RESOURCES_CONF, default={} + xgb_plugin_path = ConfigService.get_str_var( + name="xgb_plugin_path", conf=SystemConfigs.RESOURCES_CONF, default=None + ) + xgb_plugin_params: dict = ConfigService.get_dict_var( + name=PLUGIN_PARAM_KEY, conf=SystemConfigs.RESOURCES_CONF, default={} ) - # Library path is frequently used, add a scalar config var and overwrite what's in the dict - xgb_library_path = ConfigService.get_str_var(name="xgb_library_path", conf=SystemConfigs.RESOURCES_CONF) - if xgb_library_path: - xgb_loader_params[LOADER_PARAMS_LIBRARY_PATH] = xgb_library_path + # path and name can be overwritten by scalar configuration + if xgb_plugin_name: + xgb_plugin_params[PLUGIN_KEY_NAME] = xgb_plugin_name - lib_path = xgb_loader_params.get(LOADER_PARAMS_LIBRARY_PATH, None) - if not lib_path: - xgb_loader_params[LOADER_PARAMS_LIBRARY_PATH] = str(get_package_root() / "libs") + if xgb_plugin_path: + xgb_plugin_params[PLUGIN_KEY_PATH] = xgb_plugin_path - 
xgb_proc_params = ConfigService.get_dict_var( - name="xgb_proc_params", conf=SystemConfigs.RESOURCES_CONF, default={} - ) + # Set default plugin name + if not xgb_plugin_params.get(PLUGIN_KEY_NAME): + xgb_plugin_params[PLUGIN_KEY_NAME] = "cuda_paillier" - self.logger.info( - f"XGBoost secure mode: {self._training_mode} plugin_name: {xgb_plugin_name} " - f"proc_params: {xgb_proc_params} loader_params: {xgb_loader_params}" - ) + if not xgb_plugin_params.get(PLUGIN_KEY_PATH): + # This only works on Linux. Need to support other platforms + lib_ext = "so" + lib_name = f"lib{xgb_plugin_params[PLUGIN_KEY_NAME]}.{lib_ext}" + xgb_plugin_params[PLUGIN_KEY_PATH] = str(get_package_root() / "libs" / lib_name) - communicator_env.update( - { - "plugin_name": xgb_plugin_name, - "proc_params": xgb_proc_params, - "loader_params": xgb_loader_params, - } - ) + self.logger.info(f"XGBoost secure training: {self._training_mode} Params: {xgb_plugin_params}") + + communicator_env[PLUGIN_PARAM_KEY] = xgb_plugin_params with xgb.collective.CommunicatorContext(**communicator_env): # Load the data. Dmatrix must be created with column split mode in CommunicatorContext for vertical FL diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py index 32e708c90e..e4a8796a38 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py @@ -29,8 +29,8 @@ def run(self, ctx: dict): self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) xgb_federated.run_federated_server( + n_workers=self._world_size, port=self._port, - world_size=self._world_size, ) self._stopped = True diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py index 5aad654824..ea5607d828 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py @@ -299,6 +299,10 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): self._process_after_all_gather_v_vertical(fl_ctx) def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): + reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) + size_dict = reply.get_header(Constant.HEADER_KEY_SIZE_DICT) + total_size = sum(size_dict.values()) + self.info(fl_ctx, f"{total_size=} {size_dict=}") rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) # this rcv_buf is a list of replies from ALL clients! rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) @@ -309,7 +313,7 @@ def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): if not self.clear_ghs: # this is non-label client - don't care about the results - dummy = os.urandom(Constant.DUMMY_BUFFER_SIZE) + dummy = os.urandom(total_size) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=dummy, private=True, sticky=False) self.info(fl_ctx, "non-label client: return dummy buffer back to XGB") return @@ -352,16 +356,45 @@ def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): self.info(fl_ctx, f"final aggr: {gid=} features={fid_list}") result = self.data_converter.encode_aggregation_result(final_result, fl_ctx) + + # XGBoost expects every work has a set of histograms. 
They are already combined here so + # just add zeros + zero_result = final_result + for result_list in zero_result.values(): + for item in result_list: + size = len(item.aggregated_hist) + item.aggregated_hist = [(0, 0)] * size + zero_buf = self.data_converter.encode_aggregation_result(zero_result, fl_ctx) + world_size = len(size_dict) + for _ in range(world_size - 1): + result += zero_buf + + # XGBoost checks that the size of allgatherv is not changed + padding_size = total_size - len(result) + if padding_size > 0: + result += b"\x00" * padding_size + elif padding_size < 0: + self.error(fl_ctx, f"The original size {total_size} is not big enough for data size {len(result)}") + fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) def _process_after_all_gather_v_horizontal(self, fl_ctx: FLContext): + reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) + world_size = reply.get_header(Constant.HEADER_KEY_WORLD_SIZE) encrypted_histograms = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) if not isinstance(encrypted_histograms, CKKSVector): return self._abort(f"rank {rank}: expect a CKKSVector but got {type(encrypted_histograms)}", fl_ctx) histograms = encrypted_histograms.decrypt(secret_key=self.tenseal_context.secret_key()) + result = self.data_converter.encode_histograms_result(histograms, fl_ctx) + + # XGBoost expect every worker returns a histogram, all zeros are returned for other workers + zeros = [0.0] * len(histograms) + zero_buf = self.data_converter.encode_histograms_result(zeros, fl_ctx) + for _ in range(world_size - 1): + result += zero_buf fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) def handle_event(self, event_type: str, fl_ctx: FLContext): @@ -376,7 +409,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): else: self.debug(fl_ctx, "Tenseal module not loaded, horizontal secure XGBoost is not supported") except Exception as ex: - self.debug(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") + self.error(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") self.tenseal_context = None elif event_type == EventType.END_RUN: self.tenseal_context = None diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py index 53e936c7d4..47e44d17d6 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py @@ -39,6 +39,8 @@ def __init__(self): self.aggr_result_dict = None self.aggr_result_to_send = None self.aggr_result_lock = threading.Lock() + self.world_size = 0 + self.size_dict = None if tenseal_imported: decomposers.register() @@ -124,6 +126,10 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext): else: self.info(fl_ctx, f"no aggr data from {rank=}") + if self.size_dict is None: + self.size_dict = {} + + self.size_dict[rank] = request.get_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE) # only send a dummy to the Server fl_ctx.set_prop( key=Constant.PARAM_KEY_SEND_BUF, value=os.urandom(Constant.DUMMY_BUFFER_SIZE), private=True, sticky=False @@ -146,6 +152,7 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL) reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True) reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal) + with 
self.aggr_result_lock: if not self.aggr_result_to_send: if not self.aggr_result_dict: @@ -159,6 +166,10 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): # reset aggr_result_dict for next gather self.aggr_result_dict = None + self.world_size = len(self.size_dict) + reply.set_header(Constant.HEADER_KEY_WORLD_SIZE, self.world_size) + reply.set_header(Constant.HEADER_KEY_SIZE_DICT, self.size_dict) + if horizontal: length = self.aggr_result_to_send.size() else: diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py new file mode 100644 index 0000000000..6540eb519c --- /dev/null +++ b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xgboost as xgb + +from nvflare.app_opt.xgboost.data_loader import XGBDataLoader +from nvflare.app_opt.xgboost.histogram_based_v2.defs import TRAINING_MODE_MAPPING, SplitMode + + +class SecureDataLoader(XGBDataLoader): + def __init__(self, rank: int, folder: str): + """Reads CSV dataset and return XGB data matrix in vertical secure mode. + + Args: + rank: Rank of the site + folder: Folder to find the CSV files + """ + self.rank = rank + self.folder = folder + + def load_data(self, client_id: str, training_mode: str): + + train_path = f"{self.folder}/site-{self.rank + 1}/train.csv" + valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv" + + if training_mode not in TRAINING_MODE_MAPPING: + raise ValueError(f"Invalid training_mode: {training_mode}") + + data_split_mode = TRAINING_MODE_MAPPING[training_mode] + + if self.rank == 0 or data_split_mode == SplitMode.ROW: + label = "&label_column=0" + else: + label = "" + + train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=data_split_mode) + valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=data_split_mode) + + return train_data, valid_data
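For context, a minimal sketch of how the new SecureDataLoader might be used. The dataset folder and training mode below are illustrative assumptions, not part of this patch; in XGBClientRunner the loader is normally invoked inside xgb.collective.CommunicatorContext so that vertical DMatrix creation uses column split mode.

```python
# Illustrative usage only; the folder layout and mode are assumptions, not part of this patch.
from nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader import SecureDataLoader

# Each site reads its own split from <folder>/site-<rank + 1>/{train,valid}.csv
loader = SecureDataLoader(rank=0, folder="/data/xgb_dataset")  # hypothetical path

# "vertical_secure" maps to SplitMode.COL; only rank 0 (the label owner) loads the label column.
# In XGBClientRunner this call happens inside xgb.collective.CommunicatorContext(...).
train_dmatrix, valid_dmatrix = loader.load_data(client_id="site-1", training_mode="vertical_secure")
```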