Skip to content

Commit

Permalink
is encrypted.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 17, 2024
1 parent 0c057a1 commit c1730e3
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 22 deletions.
7 changes: 2 additions & 5 deletions plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr, dynamic_pointer_cast
#include <memory> // for shared_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
Expand Down Expand Up @@ -65,10 +65,7 @@ class FederatedComm : public HostComm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
[[nodiscard]] bool IsEncrypted() const override {
auto mock_ptr = std::dynamic_pointer_cast<FederatedPluginMock>(plugin_);
return !mock_ptr;
}
[[nodiscard]] bool IsEncrypted() const override { return static_cast<bool>(plugin_); }
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
Expand Down
6 changes: 5 additions & 1 deletion plugin/federated/federated_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,13 @@ FederatedPlugin::~FederatedPlugin() = default;
[[nodiscard]] FederatedPluginBase* CreateFederatedPlugin(Json config) {
auto plugin = OptionalArg<Object>(config, "federated_plugin", Object::Map{});
if (!plugin.empty()) {
auto name_it = plugin.find("name");
if (name_it != plugin.cend() && get<String const>(name_it->second) == "mock") {
return new FederatedPluginMock{};
}
auto path = get<String>(plugin["path"]);
return new FederatedPlugin{path, config};
}
return new FederatedPluginMock{};
return nullptr;
}
} // namespace xgboost::collective
2 changes: 1 addition & 1 deletion src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std
template <typename T, typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
Fn&& fn) {
if (info.IsColumnSplit()) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there
// and result is broadcasted to other workers.
auto rc = detail::TryApplyWithLabels(ctx, std::forward<Fn>(fn));
Expand Down
7 changes: 5 additions & 2 deletions src/collective/communicator-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ void Finalize();
*
* @return True if the communicator is federated.
*/
[[nodiscard]] bool IsFederated() noexcept;
[[nodiscard]] bool IsFederated();

[[nodiscard]] bool IsEncrypted() noexcept;
/**
* @brief Check whether the communicator has an encryption plugin.
*/
[[nodiscard]] bool IsEncrypted();

/**
* @brief Print the message to the communicator.
Expand Down
8 changes: 4 additions & 4 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,22 +448,22 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
std::int32_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// If vertical and secure mode, we need to sync the max_num_bins across workers
// to create the same global number of cut point bins for easier future processing
if (info.IsVerticalFederated()) {
if (info.IsVerticalFederated() && collective::IsEncrypted()) {
collective::SafeColl(collective::Allreduce(ctx, &max_num_bins, collective::Op::kMax));
}
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
if (IsCat(feature_types_, fid)) {
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
// use special AddCutPoint scheme for secure vertical federated learning
bool is_nan = AddCutPoint<WQSketch>(ctx, a, max_num_bins, p_cuts, collective::IsFederated());
bool is_nan = AddCutPoint<WQSketch>(ctx, a, max_num_bins, p_cuts, collective::IsEncrypted());
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!is_nan) {
const bst_float cpt =
const float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
const float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
} else {
// if the feature is empty, push a NaN value
Expand Down
2 changes: 1 addition & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,7 @@ class LearnerImpl : public LearnerIO {
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
monitor_.Start(__func__);
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
if (info.IsVerticalFederated()) {
if (info.IsVerticalFederated() && collective::IsEncrypted()) {
#if defined(XGBOOST_USE_FEDERATED)
// Need to encrypt the gradient before broadcasting.
common::Span<std::uint8_t> encrypted;
Expand Down
10 changes: 5 additions & 5 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ class HistEvaluator {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}
bool fed_vertical = is_secure_ && is_col_split_;
bool enc_vertical = is_secure_ && is_col_split_;

for (bst_bin_t i = ibegin; i != iend; i += d_step) {
// start working
Expand All @@ -306,7 +306,7 @@ class HistEvaluator {
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) - parent.root_gain);
if (!fed_vertical) {
if (!enc_vertical) {
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
Expand All @@ -319,7 +319,7 @@ class HistEvaluator {
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) - parent.root_gain);
if (!fed_vertical) {
if (!enc_vertical) {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
Expand Down Expand Up @@ -516,7 +516,7 @@ class HistEvaluator {
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
is_col_split_{info.IsColumnSplit()},
is_secure_{collective::IsFederated()} {
is_secure_{collective::IsEncrypted()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down Expand Up @@ -747,7 +747,7 @@ class HistMultiEvaluator {
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()},
is_secure_{collective::IsFederated()} {
is_secure_{collective::IsEncrypted()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class GloablApproxBuilder {

histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
collective::IsFederated(), hist_param_);
collective::IsEncrypted(), hist_param_);
monitor_->Stop(__func__);
}

Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class MultiTargetHistBuilder {
histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
collective::IsFederated(), hist_param_);
collective::IsEncrypted(), hist_param_);

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
p_last_tree_ = p_tree;
Expand Down Expand Up @@ -355,7 +355,7 @@ class HistUpdater {
fmat->Info().IsColumnSplit());
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), collective::IsFederated(), hist_param_);
fmat->Info().IsColumnSplit(), collective::IsEncrypted(), hist_param_);
evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
p_last_tree_ = p_tree;
monitor_->Stop(__func__);
Expand Down

0 comments on commit c1730e3

Please sign in to comment.