diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index efdd03612a0f..ee2cb403dcfe 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -1,5 +1,5 @@ /** - * Copyright 2016-2023 by XGBoost contributors + * Copyright 2016-2024, XGBoost contributors */ #include #include @@ -99,6 +99,6 @@ TEST_P(TestDefaultObjConfig, Objective) { INSTANTIATE_TEST_SUITE_P(Objective, TestDefaultObjConfig, ::testing::ValuesIn(MakeObjNamesForTest()), [](const ::testing::TestParamInfo& info) { - return ObjTestNameGenerator(info); + return ObjTestNameGenerator(info.param); }); } // namespace xgboost diff --git a/tests/cpp/objective_helpers.h b/tests/cpp/objective_helpers.h index 972747c36e21..0951f2c61109 100644 --- a/tests/cpp/objective_helpers.h +++ b/tests/cpp/objective_helpers.h @@ -21,9 +21,7 @@ inline auto MakeObjNamesForTest() { return names; } -template -inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo& info) { - auto name = std::string{info.param}; +inline std::string ObjTestNameGenerator(std::string name) { // Name must be a valid c++ symbol auto it = std::find(name.cbegin(), name.cend(), ':'); if (it != name.cend()) { diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index 4d150e27311f..7e639e45b9bd 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -56,11 +56,23 @@ void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_sco } } // namespace -class VerticalFederatedLearnerTest : public ::testing::TestWithParam { +class VerticalFederatedLearnerTest + : public ::testing::TestWithParam> { static int constexpr kWorldSize{3}; protected: - void Run(std::string tree_method, std::string device, std::string objective) { + void Run(std::string tree_method, std::string device, std::string objective, bool is_encrypted) { + // Following objectives are not yet supported. + if (is_encrypted) { + std::vector unsupported{"multi:", "quantile", "absoluteerror"}; + auto skip = std::any_of(unsupported.cbegin(), unsupported.cend(), [&](auto const &name) { + return objective.find(name) != std::string::npos; + }); + if (skip) { + GTEST_SKIP_("Not supported by the plugin."); + } + } + static auto constexpr kRows{16}; static auto constexpr kCols{16}; @@ -91,34 +103,64 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParam(GetParam()); + auto is_encrypted = get<1>(GetParam()); + return std::make_tuple(objective, is_encrypted); + } }; +namespace { +auto MakeTestParams() { + auto objs = MakeObjNamesForTest(); + std::vector> values; + for (auto const &v : objs) { + values.emplace_back(v, true); + values.emplace_back(v, false); + } + return values; +} +} // namespace + TEST_P(VerticalFederatedLearnerTest, Approx) { - std::string objective = GetParam(); - this->Run("approx", DeviceSym::CPU(), objective); + auto [objective, is_encrypted] = this->GetTestParam(); + if (is_encrypted) { + GTEST_SKIP(); + } + this->Run("approx", DeviceSym::CPU(), objective, is_encrypted); } TEST_P(VerticalFederatedLearnerTest, Hist) { - std::string objective = GetParam(); - this->Run("hist", DeviceSym::CPU(), objective); + auto [objective, is_encrypted] = this->GetTestParam(); + this->Run("hist", DeviceSym::CPU(), objective, is_encrypted); } #if defined(XGBOOST_USE_CUDA) TEST_P(VerticalFederatedLearnerTest, GPUApprox) { - std::string objective = GetParam(); - this->Run("approx", DeviceSym::CUDA(), objective); + auto [objective, is_encrypted] = this->GetTestParam(); + if (is_encrypted) { + GTEST_SKIP(); + } + this->Run("approx", DeviceSym::CUDA(), objective, is_encrypted); } TEST_P(VerticalFederatedLearnerTest, GPUHist) { - std::string objective = GetParam(); - this->Run("hist", DeviceSym::CUDA(), objective); + auto [objective, is_encrypted] = this->GetTestParam(); + if (is_encrypted) { + GTEST_SKIP(); + } + this->Run("hist", DeviceSym::CUDA(), objective, is_encrypted); } #endif // defined(XGBOOST_USE_CUDA) INSTANTIATE_TEST_SUITE_P( - FederatedLearnerObjective, VerticalFederatedLearnerTest, - ::testing::ValuesIn(MakeObjNamesForTest()), + FederatedLearnerObjective, VerticalFederatedLearnerTest, ::testing::ValuesIn(MakeTestParams()), [](const ::testing::TestParamInfo &info) { - return ObjTestNameGenerator(info); + auto name = ObjTestNameGenerator(std::get<0>(info.param)); + if (std::get<1>(info.param)) { + name += "_enc"; + } + return name; }); } // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 976ae2147a06..407dc73023d5 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -717,7 +717,7 @@ TEST_P(TestColumnSplit, Objective) { INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, ::testing::ValuesIn(MakeObjNamesForTest()), [](const ::testing::TestParamInfo& info) { - return ObjTestNameGenerator(info); + return ObjTestNameGenerator(info.param); }); namespace {