
Commit

revert.
trivialfis committed Jun 17, 2024
1 parent c1730e3 commit 185ef4b
Showing 4 changed files with 17 additions and 46 deletions.
5 changes: 3 additions & 2 deletions tests/cpp/common/test_quantile.cc
@@ -13,7 +13,7 @@
 #include "../../../src/data/simple_dmatrix.h"  // SimpleDMatrix
 #include "../collective/test_worker.h"  // for TestDistributedGlobal
 #if defined(XGBOOST_USE_FEDERATED)
-#include "../plugin/federated/test_worker.h"
+#include "../plugin/federated/test_worker.h"  // for TestEncryptedGlobal
 #endif  // defined(XGBOOST_USE_FEDERATED)
 #include "xgboost/context.h"

@@ -314,6 +314,7 @@ void DoTestColSplitQuantileSecure() {
   Context ctx;
   auto const world = collective::GetWorldSize();
   auto const rank = collective::GetRank();
+  ASSERT_TRUE(collective::IsEncrypted());
   size_t cols = 2;
   size_t rows = 3;

@@ -395,7 +396,7 @@ void DoTestColSplitQuantileSecure() {
 template <bool use_column>
 void TestColSplitQuantileSecure() {
   auto constexpr kWorkers = 2;
-  collective::TestFederatedGlobal(kWorkers, [&] { DoTestColSplitQuantileSecure<use_column>(); });
+  collective::TestEncryptedGlobal(kWorkers, [&] { DoTestColSplitQuantileSecure<use_column>(); });
 }
 }  // anonymous namespace

12 changes: 12 additions & 0 deletions tests/cpp/plugin/federated/test_worker.h
@@ -83,4 +83,16 @@ void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
     collective::Finalize();
   });
 }
+
+template <typename WorkerFn>
+void TestEncryptedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
+  TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
+    auto config = FederatedTestConfig(n_workers, port, i);
+    config["federated_plugin"] = Object{};
+    config["federated_plugin"]["name"] = String{"mock"};
+    collective::Init(config);
+    fn();
+    collective::Finalize();
+  });
+}
 }  // namespace xgboost::collective
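
Note: the new TestEncryptedGlobal differs from TestFederatedGlobal only in the two federated_plugin entries it adds to the config before collective::Init, which load the mock encryption plugin; inside the worker body, collective::IsEncrypted() then reports true (see the assertion added to test_quantile.cc above). A minimal usage sketch, assuming GTest is available and this header is in scope; the test name and body are hypothetical:

#include <gtest/gtest.h>

#include "test_worker.h"  // for TestEncryptedGlobal

namespace xgboost::collective {
// Hypothetical smoke test: run two federated workers with the mock
// encryption plugin enabled and verify the encrypted state inside each.
TEST(FederatedTestWorker, EncryptedSmoke) {
  auto constexpr kWorkers = 2;
  TestEncryptedGlobal(kWorkers, [] { ASSERT_TRUE(IsEncrypted()); });
}
}  // namespace xgboost::collective
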
34 changes: 2 additions & 32 deletions tests/cpp/plugin/test_federated_learner.cc
@@ -51,21 +51,8 @@ void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_score
 
   auto model = MakeModel(tree_method, device, objective, sliced);
   auto base_score = GetBaseScore(model);
-
-  std::unique_ptr<Learner> expected{Learner::Create({})};
-  expected->LoadModel(expected_model);
-
-  std::unique_ptr<Learner> got{Learner::Create({})};
-  got->LoadModel(model);
-
-  if (rank == 0) {
-    ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
-    HostDeviceVector<float> expected_predt;
-    expected->Predict(dmat, false, &expected_predt, 0, 0);
-    HostDeviceVector<float> got_predt;
-    expected->Predict(dmat, false, &got_predt, 0, 0);
-    ASSERT_EQ(expected_predt.HostVector(), got_predt.HostVector());
-  }
+  ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
+  ASSERT_EQ(model, expected_model) << " rank " << rank;
 }
 }  // namespace

@@ -74,19 +61,6 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string>
 
  protected:
   void Run(std::string tree_method, std::string device, std::string objective) {
-    // Following objectives are not yet supported.
-    if (objective.find("multi:") != std::string::npos) {
-      GTEST_SKIP();
-      return;
-    }
-    if (objective.find("quantile") != std::string::npos) {
-      GTEST_SKIP();
-      return;
-    }
-    if (objective.find("absoluteerror") != std::string::npos) {
-      GTEST_SKIP();
-    }
-
     static auto constexpr kRows{16};
     static auto constexpr kCols{16};

@@ -132,15 +106,11 @@ TEST_P(VerticalFederatedLearnerTest, Hist) {
 #if defined(XGBOOST_USE_CUDA)
 TEST_P(VerticalFederatedLearnerTest, GPUApprox) {
   std::string objective = GetParam();
-  // Not yet supported by the plugin system
-  GTEST_SKIP();
   this->Run("approx", DeviceSym::CUDA(), objective);
 }
 
 TEST_P(VerticalFederatedLearnerTest, GPUHist) {
   std::string objective = GetParam();
-  // Not yet supported by the plugin system
-  GTEST_SKIP();
   this->Run("hist", DeviceSym::CUDA(), objective);
 }
 #endif  // defined(XGBOOST_USE_CUDA)
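
Aside: the deleted rank-0 block compared two prediction vectors that were both produced by the expected learner (both Predict calls are on expected), so it could not catch a divergence; the replacement asserts the base score and the entire serialized model on every rank, which is strictly stronger. A sketch of that model-level check in isolation, assuming Json-based LoadModel/SaveModel on Learner as used above; the helper name is hypothetical:

#include <memory>  // for unique_ptr

#include <gtest/gtest.h>

#include "xgboost/json.h"     // for Json, Object
#include "xgboost/learner.h"  // for Learner

// Hypothetical helper: round-trip a serialized model through a fresh
// Learner and require exact structural equality of the JSON documents.
void CheckModelRoundTrip(xgboost::Json const& expected_model) {
  using xgboost::Json;
  using xgboost::Learner;
  std::unique_ptr<Learner> got{Learner::Create({})};
  got->LoadModel(expected_model);
  Json model{xgboost::Object{}};
  got->SaveModel(&model);
  ASSERT_EQ(model, expected_model);  // Json equality is structural
}
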
12 changes: 0 additions & 12 deletions tests/cpp/test_learner.cc
@@ -756,9 +756,6 @@ void TestColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Args
   auto model = GetModelWithArgs(p_fmat, tree_method, device, args);
 
   auto world_size{3};
-  if (federated && use_gpu) {
-    GTEST_SKIP();
-  }
   if (use_gpu) {
     world_size = common::AllVisibleGPUs();
     // Simulate MPU on a single GPU. Federated doesn't use nccl, can run multiple
@@ -791,26 +788,17 @@ class ColumnSplitTrainingTest
  public:
   static void TestColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu,
                                            bool federated) {
-    if (federated) {
-      GTEST_SKIP();
-    }
     Args args{
         {"colsample_bytree", "0.5"}, {"colsample_bylevel", "0.6"}, {"colsample_bynode", "0.7"}};
     TestColumnSplitWithArgs(tree_method, use_gpu, args, federated);
   }
   static void TestColumnSplitInteractionConstraints(std::string const& tree_method, bool use_gpu,
                                                     bool federated) {
-    if (federated) {
-      GTEST_SKIP();
-    }
     Args args{{"interaction_constraints", "[[0, 5, 7], [2, 8, 9], [1, 3, 6]]"}};
     TestColumnSplitWithArgs(tree_method, use_gpu, args, federated);
   }
   static void TestColumnSplitMonotoneConstraints(std::string const& tree_method, bool use_gpu,
                                                  bool federated) {
-    if (federated) {
-      GTEST_SKIP();
-    }
     Args args{{"monotone_constraints", "(1,-1,0,1,1,-1,-1,0,0,1)"}};
     TestColumnSplitWithArgs(tree_method, use_gpu, args, federated);
   }
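
The three wrappers differ only in the Args they forward to TestColumnSplitWithArgs, so removing the federated early-outs brings column sampling, interaction constraints, and monotone constraints under federated column-split coverage as well. For reference, a sketch of the same constraint parameters applied through the public Learner API; the helper name is hypothetical and learner construction is elided:

#include "xgboost/learner.h"  // for Learner

// Hypothetical helper: apply the constraint parameters used in the tests.
void ApplyConstraintArgs(xgboost::Learner* learner) {
  // interaction_constraints: JSON lists of features allowed to interact.
  learner->SetParam("interaction_constraints", "[[0, 5, 7], [2, 8, 9], [1, 3, 6]]");
  // monotone_constraints: one of {1, -1, 0} per feature, in feature order.
  learner->SetParam("monotone_constraints", "(1,-1,0,1,1,-1,-1,0,0,1)");
  // colsample_bytree: fraction of features sampled for each tree.
  learner->SetParam("colsample_bytree", "0.5");
}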
