Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 17, 2024
1 parent 185ef4b commit dc57ef8
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 19 deletions.
4 changes: 2 additions & 2 deletions tests/cpp/objective/test_objective.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2016-2023 by XGBoost contributors
* Copyright 2016-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
Expand Down Expand Up @@ -99,6 +99,6 @@ TEST_P(TestDefaultObjConfig, Objective) {
INSTANTIATE_TEST_SUITE_P(Objective, TestDefaultObjConfig,
::testing::ValuesIn(MakeObjNamesForTest()),
[](const ::testing::TestParamInfo<TestDefaultObjConfig::ParamType>& info) {
return ObjTestNameGenerator(info);
return ObjTestNameGenerator(info.param);
});
} // namespace xgboost
4 changes: 1 addition & 3 deletions tests/cpp/objective_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ inline auto MakeObjNamesForTest() {
return names;
}

template <typename ParamType>
inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo<ParamType>& info) {
auto name = std::string{info.param};
inline std::string ObjTestNameGenerator(std::string name) {
// Name must be a valid c++ symbol
auto it = std::find(name.cbegin(), name.cend(), ':');
if (it != name.cend()) {
Expand Down
68 changes: 55 additions & 13 deletions tests/cpp/plugin/test_federated_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,23 @@ void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_sco
}
} // namespace

class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
class VerticalFederatedLearnerTest
: public ::testing::TestWithParam<std::tuple<std::string, bool>> {
static int constexpr kWorldSize{3};

protected:
void Run(std::string tree_method, std::string device, std::string objective) {
void Run(std::string tree_method, std::string device, std::string objective, bool is_encrypted) {
// Following objectives are not yet supported.
if (is_encrypted) {
std::vector<std::string> unsupported{"multi:", "quantile", "absoluteerror"};
auto skip = std::any_of(unsupported.cbegin(), unsupported.cend(), [&](auto const &name) {
return objective.find(name) != std::string::npos;
});
if (skip) {
GTEST_SKIP_("Not supported by the plugin.");
}
}

static auto constexpr kRows{16};
static auto constexpr kCols{16};

Expand Down Expand Up @@ -91,34 +103,64 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string
VerifyObjective(kRows, kCols, score, model, tree_method, device, objective);
});
}

auto GetTestParam() {
std::string objective = get<0>(GetParam());
auto is_encrypted = get<1>(GetParam());
return std::make_tuple(objective, is_encrypted);
}
};

namespace {
auto MakeTestParams() {
auto objs = MakeObjNamesForTest();
std::vector<std::tuple<std::string, bool>> values;
for (auto const &v : objs) {
values.emplace_back(v, true);
values.emplace_back(v, false);
}
return values;
}
} // namespace

TEST_P(VerticalFederatedLearnerTest, Approx) {
std::string objective = GetParam();
this->Run("approx", DeviceSym::CPU(), objective);
auto [objective, is_encrypted] = this->GetTestParam();
if (is_encrypted) {
GTEST_SKIP();
}
this->Run("approx", DeviceSym::CPU(), objective, is_encrypted);
}

TEST_P(VerticalFederatedLearnerTest, Hist) {
std::string objective = GetParam();
this->Run("hist", DeviceSym::CPU(), objective);
auto [objective, is_encrypted] = this->GetTestParam();
this->Run("hist", DeviceSym::CPU(), objective, is_encrypted);
}

#if defined(XGBOOST_USE_CUDA)
TEST_P(VerticalFederatedLearnerTest, GPUApprox) {
std::string objective = GetParam();
this->Run("approx", DeviceSym::CUDA(), objective);
auto [objective, is_encrypted] = this->GetTestParam();
if (is_encrypted) {
GTEST_SKIP();
}
this->Run("approx", DeviceSym::CUDA(), objective, is_encrypted);
}

TEST_P(VerticalFederatedLearnerTest, GPUHist) {
std::string objective = GetParam();
this->Run("hist", DeviceSym::CUDA(), objective);
auto [objective, is_encrypted] = this->GetTestParam();
if (is_encrypted) {
GTEST_SKIP();
}
this->Run("hist", DeviceSym::CUDA(), objective, is_encrypted);
}
#endif // defined(XGBOOST_USE_CUDA)

INSTANTIATE_TEST_SUITE_P(
FederatedLearnerObjective, VerticalFederatedLearnerTest,
::testing::ValuesIn(MakeObjNamesForTest()),
FederatedLearnerObjective, VerticalFederatedLearnerTest, ::testing::ValuesIn(MakeTestParams()),
[](const ::testing::TestParamInfo<VerticalFederatedLearnerTest::ParamType> &info) {
return ObjTestNameGenerator(info);
auto name = ObjTestNameGenerator(std::get<0>(info.param));
if (std::get<1>(info.param)) {
name += "_enc";
}
return name;
});
} // namespace xgboost
2 changes: 1 addition & 1 deletion tests/cpp/test_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ TEST_P(TestColumnSplit, Objective) {
INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit,
::testing::ValuesIn(MakeObjNamesForTest()),
[](const ::testing::TestParamInfo<TestColumnSplit::ParamType>& info) {
return ObjTestNameGenerator(info);
return ObjTestNameGenerator(info.param);
});

namespace {
Expand Down

0 comments on commit dc57ef8

Please sign in to comment.