Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Federated learning plugin. #10410

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
enum class DataSplitMode : int { kRow = 0, kCol = 1 };

/*!
* \brief Meta information about dataset, always sit in memory.
Expand Down Expand Up @@ -171,17 +171,11 @@ class MetaInfo {
*/
void SynchronizeNumberOfColumns(Context const* ctx);

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return data_split_mode == DataSplitMode::kRow;
}
/** @brief Whether the data is split row-wise. */
[[nodiscard]] bool IsRowSplit() const { return data_split_mode == DataSplitMode::kRow; }

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
|| (data_split_mode == DataSplitMode::kColSecure); }

/** @brief Whether the data is split column-wise with secure computation. */
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
[[nodiscard]] bool IsColumnSplit() const { return !this->IsRowSplit(); }

/** @brief Whether this is a learning to rank data. */
bool IsRanking() const { return !group_ptr_.empty(); }
Expand Down
5 changes: 5 additions & 0 deletions include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
return linalg::TensorView<T, 1>{{ptr, s}, {s}, device};
}

/**
 * @brief Create a 1-dimensional tensor view over an existing contiguous span.
 *
 * @param data   The span to view; the view does not own the data.
 * @param device Device ordinal for the view, defaults to CPU.
 */
template <typename T>
auto MakeVec(common::Span<T> data, DeviceOrd device = DeviceOrd::CPU()) {
  auto n = data.size();
  return linalg::TensorView<T, 1>{data, {n}, device};
}

template <typename T>
auto MakeVec(HostDeviceVector<T> *data) {
return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(), data->Size(),
Expand Down
2 changes: 1 addition & 1 deletion plugin/example/custom_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction {

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("my_logistic");
out["name"] = String("mylogistic");
out["my_logistic_param"] = ToJson(param_);
}

Expand Down
3 changes: 2 additions & 1 deletion plugin/federated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc
objxgboost PRIVATE
federated_plugin.cc federated_hist.cc federated_tracker.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
Expand Down
10 changes: 7 additions & 3 deletions plugin/federated/federated_coll.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "federated_coll.h"

Expand All @@ -8,11 +8,15 @@

#include <algorithm> // for copy_n

#include "../../src/collective/allgather.h"
#include "../../src/common/common.h" // for AssertGPUSupport
#include "federated_comm.h" // for FederatedComm
#include "xgboost/collective/result.h" // for Result

#if !defined(XGBOOST_USE_CUDA)

#include "../../src/common/common.h" // for AssertGPUSupport

#endif // !defined(XGBOOST_USE_CUDA)

namespace xgboost::collective {
namespace {
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
Expand Down
18 changes: 12 additions & 6 deletions plugin/federated/federated_comm.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "federated_comm.h"

Expand All @@ -8,6 +8,7 @@
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <memory> // for make_shared
#include <string> // for string, stoi

#include "../../src/common/common.h" // for Split
Expand All @@ -31,7 +32,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
CHECK_LT(rank, world) << "Invalid worker rank.";

auto certs = {server_cert, client_cert, client_cert};
auto is_empty = [](auto const& s) { return s.empty(); };
auto is_empty = [](auto const& s) {
return s.empty();
};
bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
std::none_of(certs.begin(), certs.end(), is_empty);
CHECK(valid) << "Invalid arguments for certificates.";
Expand All @@ -53,8 +56,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
channel->WaitForConnected(gpr_time_add(
gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(DefaultTimeoutSec(), GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
}();
}
Expand Down Expand Up @@ -90,8 +93,6 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;

CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";

/**
Expand Down Expand Up @@ -124,6 +125,11 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

/**
* Hist encryption plugin.
*/
this->plugin_.reset(CreateFederatedPlugin(config));

this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
Expand Down
15 changes: 9 additions & 6 deletions plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
#include <federated.grpc.pb.h>
#include <federated.pb.h>

#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <memory> // for shared_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
#include "../../src/collective/comm.h" // for HostComm
#include "federated_plugin.h" // for FederatedPlugin
#include "xgboost/json.h"

namespace xgboost::collective {
class FederatedComm : public HostComm {
std::shared_ptr<federated::Federated::Stub> stub_;
// Plugin for encryption
std::shared_ptr<FederatedPluginBase> plugin_{nullptr};

void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
Expand Down Expand Up @@ -46,10 +50,6 @@ class FederatedComm : public HostComm {
*/
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
Expand All @@ -65,6 +65,7 @@ class FederatedComm : public HostComm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
[[nodiscard]] bool IsEncrypted() const override { return static_cast<bool>(plugin_); }
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
Expand All @@ -76,5 +77,7 @@ class FederatedComm : public HostComm {
*out = "rank:" + std::to_string(rank);
return Success();
};

auto EncryptionPlugin() const { return plugin_; }
};
} // namespace xgboost::collective
149 changes: 149 additions & 0 deletions plugin/federated/federated_hist.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "federated_hist.h"

#include "../../src/collective/allgather.h" // for AllgatherV
#include "../../src/collective/communicator-inl.h" // for GetRank
#include "../../src/tree/hist/histogram.h" // for SubtractHistParallel, BuildSampleHistograms

namespace xgboost::tree {
/**
 * @brief Build the local histograms for the given set of tree nodes.
 *
 * Under column (vertical) split, histogram construction is delegated to the
 * encryption plugin: the bin-index matrix is transmitted once, then an encrypted
 * histogram buffer is requested per node and stored in `hist_data_` for the later
 * synchronization step.  Note that `gpair_h` is not touched on this path.  In all
 * other cases the regular multi-threaded sample-based histogram build is used.
 *
 * @tparam any_missing Whether the feature matrix may contain missing values.
 */
template <bool any_missing>
void FederataedHistPolicy::DoBuildLocalHistograms(
    common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
    std::vector<bst_node_t> const &nodes_to_build,
    common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
    bool force_read_by_column, common::ParallelGHistBuilder *buffer) {
  if (is_col_split_) {
    // Call the interface to transmit gidx information to the secure worker for encrypted
    // histogram computation
    auto cuts = gidx.Cuts().Ptrs();
    // fixme: this can be done during reset.
    if (!is_aggr_context_initialized_) {
      // Flatten the row x feature bin-index matrix so the plugin can map every
      // sample to its histogram slot.
      auto slots = std::vector<int>();
      auto num_rows = gidx.Size();
      for (std::size_t row = 0; row < num_rows; row++) {
        for (std::size_t f = 0; f < cuts.size() - 1; f++) {
          auto slot = gidx.GetGindex(row, f);
          slots.push_back(slot);
        }
      }
      plugin_->Reset(cuts, slots);
      // One-time handshake: the slot mapping is only sent on the first build.
      is_aggr_context_initialized_ = true;
    }

    // Further use the row set collection info to
    // get the encrypted histogram from the secure worker
    std::vector<std::uint64_t const *> ptrs(nodes_to_build.size());
    std::vector<std::size_t> sizes(nodes_to_build.size());
    std::vector<bst_node_t> nodes(nodes_to_build.size());
    for (std::size_t i = 0; i < nodes_to_build.size(); ++i) {
      auto nidx = nodes_to_build[i];
      ptrs[i] = row_set_collection[nidx].begin;
      sizes[i] = row_set_collection[nidx].Size();
      nodes[i] = nidx;
    }
    // NOTE(review): hist_data_ appears to reference plugin-managed storage that is
    // consumed by DoSyncHistogram — verify its lifetime against the plugin contract.
    hist_data_ = this->plugin_->BuildEncryptedHistVert(ptrs, sizes, nodes);
  } else {
    // Plain (non-encrypted) path: standard parallel histogram build.
    BuildSampleHistograms<any_missing>(this->n_threads_, space, gidx, nodes_to_build,
                                       row_set_collection, gpair_h, force_read_by_column, buffer);
  }
}

// Explicit instantiations for both values of `any_missing`, so the template
// definition can live in this translation unit.
template void FederataedHistPolicy::DoBuildLocalHistograms<true>(
    common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
    std::vector<bst_node_t> const &nodes_to_build,
    common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
    bool force_read_by_column, common::ParallelGHistBuilder *buffer);
template void FederataedHistPolicy::DoBuildLocalHistograms<false>(
    common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
    std::vector<bst_node_t> const &nodes_to_build,
    common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
    bool force_read_by_column, common::ParallelGHistBuilder *buffer);

/**
 * @brief Merge per-thread histograms and synchronize them across federated workers.
 *
 * Vertical (column-split) mode: the encrypted buffers produced by
 * DoBuildLocalHistograms are gathered from all workers, decoded by the plugin, and
 * accumulated into the local histogram — on the label owner (rank 0) only.
 *
 * Horizontal (row-split) mode: the locally reduced histogram is encrypted by the
 * plugin, gathered across workers (the actual aggregation is performed at the
 * server side), decoded, and written back into the local histogram.
 *
 * @param ctx            Runtime context.
 * @param p_tree         The tree being built, used for the subtraction trick.
 * @param nodes_to_build Nodes whose histograms are built directly.
 * @param nodes_to_trick Nodes whose histograms are derived via subtraction.
 * @param buffer         Per-thread histogram buffer to reduce.
 * @param p_hist         Output histogram collection.
 */
void FederataedHistPolicy::DoSyncHistogram(Context const *ctx, RegTree const *p_tree,
                                           std::vector<bst_node_t> const &nodes_to_build,
                                           std::vector<bst_node_t> const &nodes_to_trick,
                                           common::ParallelGHistBuilder *buffer,
                                           tree::BoundedHistCollection *p_hist) {
  auto n_total_bins = buffer->TotalBins();
  common::BlockedSpace2d space(
      nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024);
  CHECK(!nodes_to_build.empty());

  auto &hist = *p_hist;
  if (is_col_split_) {
    // Under secure vertical mode, we perform allgather to get the global histogram. Note
    // that only the label owner (rank == 0) needs the global histogram

    // Gather the encrypted histograms (from DoBuildLocalHistograms) of all workers.
    HostDeviceVector<std::int8_t> hist_entries;
    std::vector<std::int64_t> recv_segments;
    collective::SafeColl(
        collective::AllgatherV(ctx, linalg::MakeVec(hist_data_), &recv_segments, &hist_entries));

    // Call interface here to post-process the messages into plain-text values.
    common::Span<double> hist_aggr =
        plugin_->SyncEncryptedHistVert(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

    // Update histogram for label owner
    if (collective::GetRank() == 0) {
      bst_node_t n_nodes = nodes_to_build.size();
      std::int32_t n_workers = collective::GetWorldSize();
      // The gathered buffer is one equally-sized chunk per worker.
      bst_idx_t worker_size = hist_aggr.size() / n_workers;
      CHECK_EQ(hist_aggr.size() % n_workers, 0);
      // Initialize histogram. For the normal case, this is done by the parallel hist
      // buffer. We should try to unify the code paths.
      for (auto nidx : nodes_to_build) {
        auto hist_dst = hist[nidx];
        std::fill_n(hist_dst.data(), hist_dst.size(), GradientPairPrecise{});
      }

      // Accumulate every worker's contribution into each node's histogram.
      for (std::int32_t widx = 0; widx < n_workers; ++widx) {
        auto worker_hist = hist_aggr.subspan(widx * worker_size, worker_size);
        for (bst_node_t nidx_in_set = 0; nidx_in_set < n_nodes; ++nidx_in_set) {
          // Each bin contributes one GradientPairPrecise, i.e. two doubles.
          auto hist_src = worker_hist.subspan(n_total_bins * 2 * nidx_in_set, n_total_bins * 2);
          auto hist_src_g = common::RestoreType<GradientPairPrecise>(hist_src);
          auto hist_dst = hist[nodes_to_build[nidx_in_set]];
          CHECK_EQ(hist_src_g.size(), hist_dst.size());
          common::IncrementHist(hist_dst, hist_src_g, 0, hist_dst.size());
        }
      }
    }
  } else {
    common::ParallelFor2d(space, this->n_threads_, [&](std::size_t node, common::Range1d r) {
      // Merging histograms from each thread.
      buffer->ReduceHist(node, r.begin(), r.end());
    });
    // Secure horizontal mode: call the plugin to encrypt/decrypt the histogram;
    // note that the actual aggregation will be performed at server side.
    auto first_nidx = nodes_to_build.front();
    // Assumes the built nodes' histograms are laid out contiguously starting at the
    // first node, each bin being a GradientPairPrecise (two doubles).
    std::size_t n = n_total_bins * nodes_to_build.size() * 2;
    // Hoist the cast out of the copy loops: `hist[first_nidx].data()` refers to the
    // same storage on every access.
    auto *local_hist = reinterpret_cast<double *>(hist[first_nidx].data());
    // Range-construct the buffer instead of pushing back element by element.
    std::vector<double> hist_to_aggr(local_hist, local_hist + n);
    auto hist_buf = plugin_->BuildEncryptedHistHori(hist_to_aggr);

    // Gather the encrypted buffers from all workers.
    HostDeviceVector<std::int8_t> hist_entries;
    std::vector<std::int64_t> recv_segments;
    auto rc = collective::AllgatherV(ctx, linalg::MakeVec(hist_buf), &recv_segments, &hist_entries);
    collective::SafeColl(rc);

    auto hist_aggr =
        plugin_->SyncEncryptedHistHori(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
    // Assign the aggregated histogram back to the local histogram.
    for (std::size_t hist_idx = 0; hist_idx < n; ++hist_idx) {
      local_hist[hist_idx] = hist_aggr[hist_idx];
    }
  }

  SubtractHistParallel(ctx, space, p_tree, nodes_to_build, nodes_to_trick, buffer, p_hist);
}
} // namespace xgboost::tree
Loading
Loading