Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap the xgboost plugin into a C library. #2639

Merged
merged 9 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions integration/xgboost/processor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,22 @@ set(CMAKE_BUILD_TYPE Debug)

option(GOOGLE_TEST "Build google tests" OFF)

file(GLOB_RECURSE LIB_SRC
"src/*.h"
"src/*.cc"
)
file(GLOB_RECURSE LIB_SRC "src/*.cc")

add_library(proc_nvflare SHARED ${LIB_SRC})
set(XGB_SRC ${proc_nvflare_SOURCE_DIR}/../../../../xgboost)
target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include
${XGB_SRC}/src
${XGB_SRC}/rabit/include
${XGB_SRC}/include
${XGB_SRC}/dmlc-core/include)

link_directories(${XGB_SRC}/lib/)
set_target_properties(proc_nvflare PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON
ENABLE_EXPORTS ON
)
target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include)

if (APPLE)
add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o")
add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
endif ()

target_link_libraries(proc_nvflare ${XGB_SRC}/lib/libxgboost${CMAKE_SHARED_LIBRARY_SUFFIX})

#-- Unit Tests
if(GOOGLE_TEST)
find_package(GTest REQUIRED)
Expand All @@ -49,4 +43,4 @@ if(GOOGLE_TEST)
COMMAND proc_test
WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR})

endif()
endif()
12 changes: 4 additions & 8 deletions integration/xgboost/processor/README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
# Build Instruction

This plugin build requires xgboost source code, checkout xgboost source and build it with FEDERATED plugin,

cd xgboost
mkdir build
cd build
cmake .. -DPLUGIN_FEDERATED=ON
make

``` sh
cd NVFlare/integration/xgboost/processor
mkdir build
cd build
cmake ..
make
```

See [tests](./tests) for simple examples.
5 changes: 1 addition & 4 deletions integration/xgboost/processor/src/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# encoding-plugins
Processor Plugin for NVFlare

This plugin is a companion for NVFlare based encryption, it processes the data so it can
This plugin is a companion for NVFlare based encryption, it processes the data so it can
be properly decoded by Python code running on NVFlare.

All the encryption is happening on the local GRPC client/server so no encryption is needed
in this plugin.



5 changes: 1 addition & 4 deletions integration/xgboost/processor/src/dam/README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
# DAM (Direct-Accessible Marshaller)

A simple serialization library that doesn't have dependencies, and the data
A simple serialization library that doesn't have dependencies, and the data
is directly accessible in C/C++ without copying.

To make the data accessible in C, following rules must be followed,

1. Numeric values must be stored in native byte-order.
2. Numeric values must start at the 64-bit boundaries (8-bytes)



66 changes: 33 additions & 33 deletions integration/xgboost/processor/src/dam/dam.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ void print_buffer(uint8_t *buffer, int size) {

// DamEncoder ======
void DamEncoder::AddFloatArray(const std::vector<double> &value) {
if (encoded) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
auto buf_size = value.size()*8;
uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size));
memcpy(buffer, value.data(), buf_size);
// print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8);
entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size()));
if (encoded) {
std::cout << "Buffer is already encoded" << std::endl;
return;
}
auto buf_size = value.size() * 8;
uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size));
memcpy(buffer, value.data(), buf_size);
entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size()));
}

void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
Expand All @@ -52,15 +51,15 @@ void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size()));
}

std::uint8_t * DamEncoder::Finish(size_t &size) {
std::vector<std::uint8_t> DamEncoder::Finish(size_t &size) {
encoded = true;

size = calculate_size();
auto buf = static_cast<uint8_t *>(malloc(size));
auto pointer = buf;
std::vector<std::uint8_t> buf(size);
auto pointer = buf.data();
memcpy(pointer, kSignature, strlen(kSignature));
memcpy(pointer+8, &size, 8);
memcpy(pointer+16, &data_set_id, 8);
memcpy(pointer + 8, &size, 8);
memcpy(pointer + 16, &data_set_id, 8);

pointer += kPrefixLen;
for (auto entry : *entries) {
Expand All @@ -75,9 +74,9 @@ std::uint8_t * DamEncoder::Finish(size_t &size) {
// print_buffer(entry->pointer, entry->size*8);
}

if ((pointer - buf) != size) {
std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl;
return nullptr;
if ((pointer - buf.data()) != size) {
throw std::runtime_error{"Invalid encoded size: " +
std::to_string(pointer - buf.data())};
}

return buf;
Expand All @@ -97,7 +96,7 @@ std::size_t DamEncoder::calculate_size() {

// DamDecoder ======

DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size) {
DamDecoder::DamDecoder(std::uint8_t const *buffer, std::size_t size) {
this->buffer = buffer;
this->buf_size = size;
this->pos = buffer + kPrefixLen;
Expand All @@ -115,32 +114,33 @@ bool DamDecoder::IsValid() {
}

std::vector<int64_t> DamDecoder::DecodeIntArray() {
auto type = *reinterpret_cast<int64_t *>(pos);
if (type != kDataTypeIntArray) {
std::cout << "Data type " << type << " doesn't match Int Array" << std::endl;
return std::vector<int64_t>();
}
pos += 8;

auto len = *reinterpret_cast<int64_t *>(pos);
pos += 8;
auto ptr = reinterpret_cast<int64_t *>(pos);
pos += 8*len;
return std::vector<int64_t>(ptr, ptr + len);
auto type = *reinterpret_cast<int64_t const*>(pos);
if (type != kDataTypeIntArray) {
std::cout << "Data type " << type << " doesn't match Int Array"
<< std::endl;
return std::vector<int64_t>();
}
pos += 8;

auto len = *reinterpret_cast<int64_t const *>(pos);
pos += 8;
auto ptr = reinterpret_cast<int64_t const *>(pos);
pos += 8 * len;
return std::vector<int64_t>(ptr, ptr + len);
}

std::vector<double> DamDecoder::DecodeFloatArray() {
auto type = *reinterpret_cast<int64_t *>(pos);
auto type = *reinterpret_cast<int64_t const*>(pos);
if (type != kDataTypeFloatArray) {
std::cout << "Data type " << type << " doesn't match Float Array" << std::endl;
return std::vector<double>();
}
pos += 8;

auto len = *reinterpret_cast<int64_t *>(pos);
auto len = *reinterpret_cast<int64_t const *>(pos);
pos += 8;

auto ptr = reinterpret_cast<double *>(pos);
auto ptr = reinterpret_cast<double const *>(pos);
pos += 8*len;
return std::vector<double>(ptr, ptr + len);
}
16 changes: 8 additions & 8 deletions integration/xgboost/processor/src/include/dam.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
* limitations under the License.
*/
#pragma once
#include <string>
#include <vector>
#include <map>
#include <cstdint> // for int64_t
#include <cstddef> // for size_t

const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1
const int kPrefixLen = 24;
Expand Down Expand Up @@ -57,23 +57,23 @@ class DamEncoder {

void AddFloatArray(const std::vector<double> &value);

std::uint8_t * Finish(size_t &size);
std::vector<std::uint8_t> Finish(size_t &size);

private:
private:
std::size_t calculate_size();
};

class DamDecoder {
private:
std::uint8_t *buffer = nullptr;
std::uint8_t const *buffer = nullptr;
std::size_t buf_size = 0;
std::uint8_t *pos = nullptr;
std::uint8_t const *pos = nullptr;
std::size_t remaining = 0;
int64_t data_set_id = 0;
int64_t len = 0;

public:
explicit DamDecoder(std::uint8_t *buffer, std::size_t size);
public:
explicit DamDecoder(std::uint8_t const *buffer, std::size_t size);

size_t Size() {
return len;
Expand Down
92 changes: 45 additions & 47 deletions integration/xgboost/processor/src/include/nvflare_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include <map>
#include "processing/processor.h"
#include <cstdint> // for uint8_t, uint32_t, int32_t, int64_t
#include <string_view> // for string_view
#include <utility> // for pair
#include <vector> // for vector

const int kDataSetHGPairs = 1;
const int kDataSetAggregation = 2;
Expand All @@ -27,50 +26,49 @@ const int kDataSetAggregationResult = 4;
const int kDataSetHistograms = 5;
const int kDataSetHistogramResult = 6;

class NVFlareProcessor: public processing::Processor {
private:
bool active_ = false;
const std::map<std::string, std::string> *params_;
std::vector<double> *gh_pairs_{nullptr};
std::vector<uint32_t> cuts_;
std::vector<int> slots_;
bool feature_sent_ = false;
std::vector<int64_t> features_;
// Opaque pointer type for the C API.
typedef void *FederatedPluginHandle; // NOLINT

public:
void Initialize(bool active, std::map<std::string, std::string> params) override {
this->active_ = active;
this->params_ = &params;
}
namespace nvflare {
// Plugin that uses Python tenseal and GRPC.
class TensealPlugin {
// Buffer for storing encrypted gradient pairs.
std::vector<std::uint8_t> encrypted_gpairs_;
// Buffer for histogram cut pointers (indptr of a CSC).
std::vector<std::uint32_t> cut_ptrs_;
// Buffer for histogram index.
std::vector<std::int32_t> bin_idx_;

void Shutdown() override {
this->gh_pairs_ = nullptr;
this->cuts_.clear();
this->slots_.clear();
}
bool feature_sent_{false};
// The feature index.
std::vector<std::int64_t> features_;
// Buffer for output histogram.
std::vector<std::uint8_t> encrypted_hist_;
std::vector<double> hist_;

void FreeBuffer(void *buffer) override {
free(buffer);
}
public:
TensealPlugin(
std::vector<std::pair<std::string_view, std::string_view>> const &args);
// Gradient pairs
void EncryptGPairs(float const *in_gpair, std::size_t n_in,
std::uint8_t **out_gpair, std::size_t *n_out);
void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes,
std::uint8_t const **out_gpair,
std::size_t *out_n_bytes);

void* ProcessGHPairs(size_t *size, const std::vector<double>& pairs) override;
// Histogram
void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len,
std::int32_t const *bin_idx, std::size_t n_idx);
void BuildEncryptedHistHori(double const *in_histogram, std::size_t len,
std::uint8_t **out_hist, std::size_t *out_len);
void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len,
double **out_hist, std::size_t *out_len);

void* HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override;

void InitAggregationContext(const std::vector<uint32_t> &cuts, const std::vector<int> &slots) override {
if (this->slots_.empty()) {
this->cuts_ = std::vector<uint32_t>(cuts);
this->slots_ = std::vector<int>(slots);
} else {
std::cout << "Multiple calls to InitAggregationContext" << std::endl;
}
}

void *ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) override;

std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;

void *ProcessHistograms(size_t *size, const std::vector<double>& histograms) override;

std::vector<double> HandleHistograms(void *buffer, size_t buf_size) override;
};
void BuildEncryptedHistVert(std::size_t const **ridx,
std::size_t const *sizes,
std::int32_t const *nidx, std::size_t len,
std::uint8_t **out_hist, std::size_t *out_len);
void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len,
double **out, std::size_t *out_len);
};
} // namespace nvflare
Loading
Loading