diff --git a/src/tests/test_ucc_verifier.cpp b/src/tests/test_ucc_verifier.cpp index 56027a0ffd..3212a823ed 100644 --- a/src/tests/test_ucc_verifier.cpp +++ b/src/tests/test_ucc_verifier.cpp @@ -8,7 +8,7 @@ namespace tests { namespace onam = config::names; -class UCCVerifyingSimpleParams { +class UCCVerifierSimpleParams { private: algos::StdParamsMap params_map_; size_t num_clusters_violating_ucc_ = 0; @@ -16,12 +16,12 @@ class UCCVerifyingSimpleParams { std::vector clusters_violating_ucc_; public: - UCCVerifyingSimpleParams(config::IndicesType column_indices, - size_t const num_clusters_violating_ucc, - size_t const num_rows_violating_ucc, - std::vector clusters_violating_ucc, - std::string_view dataset, char const separator = ',', - bool const has_header = true) + UCCVerifierSimpleParams(config::IndicesType column_indices, + size_t const num_clusters_violating_ucc, + size_t const num_rows_violating_ucc, + std::vector clusters_violating_ucc, + std::string_view dataset, char const separator = ',', + bool const has_header = true) : params_map_({{onam::kUCCIndices, std::move(column_indices)}, {onam::kCsvPath, test_data_dir / dataset}, {onam::kSeparator, separator}, @@ -48,10 +48,39 @@ class UCCVerifyingSimpleParams { } }; -class TestUCCVerifyingSimple : public ::testing::TestWithParam {}; +class TestUCCVerifierSimple : public ::testing::TestWithParam {}; -TEST_P(TestUCCVerifyingSimple, DefaultTest) { - UCCVerifyingSimpleParams const& p(GetParam()); +class UCCVerifierWithHyUCCParams { +private: + algos::StdParamsMap ucc_verifier_params_map_; + algos::StdParamsMap hyucc_params_map_; + +public: + explicit UCCVerifierWithHyUCCParams(std::string_view dataset, char const separator = ',', + bool const has_header = true) + : ucc_verifier_params_map_({{onam::kCsvPath, test_data_dir / dataset}, + {onam::kSeparator, separator}, + {onam::kHasHeader, has_header}, + {onam::kEqualNulls, true}}), + hyucc_params_map_({{onam::kThreads, static_cast(1)}, + {onam::kCsvPath, test_data_dir / dataset}, + {onam::kSeparator, separator}, + {onam::kHasHeader, has_header}, + {onam::kEqualNulls, true}}) {} + + algos::StdParamsMap GetUCCVerifierParamsMap() const { + return ucc_verifier_params_map_; + } + + algos::StdParamsMap GetHyUCCParamsMap() const { + return hyucc_params_map_; + } +}; + +class TestUCCVerifierWithHyUCC : public ::testing::TestWithParam {}; + +TEST_P(TestUCCVerifierSimple, DefaultTest) { + UCCVerifierSimpleParams const& p(GetParam()); auto verifier = algos::CreateAndLoadAlgorithm(p.GetParamsMap()); verifier->Execute(); EXPECT_EQ(verifier->UCCHolds(), p.GetExpectedNumClustersViolatingUCC() == 0); @@ -61,17 +90,57 @@ TEST_P(TestUCCVerifyingSimple, DefaultTest) { } INSTANTIATE_TEST_SUITE_P( - UCCVerifierSimpleTestSuite, TestUCCVerifyingSimple, - ::testing::Values(UCCVerifyingSimpleParams({0}, 1, 12, - {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}, - "TestFD.csv"), - UCCVerifyingSimpleParams({0, 1}, 4, 12, - {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}}, - "TestFD.csv"), - UCCVerifyingSimpleParams({0, 1, 2}, 4, 8, - {{0, 1}, {3, 4}, {6, 7}, {9, 10}}, "TestFD.csv"), - UCCVerifyingSimpleParams({0, 1, 2, 3, 4, 5}, 3, 6, - {{3, 4}, {6, 7}, {9, 10}}, "TestFD.csv"), - UCCVerifyingSimpleParams({0}, 0, 0, {}, "TestWide.csv"), - UCCVerifyingSimpleParams({0, 1, 2, 3, 4}, 0, 0, {}, "TestWide.csv"))); + UCCVerifierSimpleTestSuite, TestUCCVerifierSimple, + ::testing::Values(UCCVerifierSimpleParams({0}, 1, 12, + {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}, + "TestFD.csv"), + UCCVerifierSimpleParams({0, 1}, 4, 12, + {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}}, + "TestFD.csv"), + UCCVerifierSimpleParams({0, 1, 2}, 4, 8, + {{0, 1}, {3, 4}, {6, 7}, {9, 10}}, "TestFD.csv"), + UCCVerifierSimpleParams({0, 1, 2, 3, 4, 5}, 3, 6, + {{3, 4}, {6, 7}, {9, 10}}, "TestFD.csv"), + UCCVerifierSimpleParams({0}, 0, 0, {}, "TestWide.csv"), + UCCVerifierSimpleParams({0, 1, 2, 3, 4}, 0, 0, {}, "TestWide.csv"))); + +TEST_P(TestUCCVerifierWithHyUCC, TestWithHyUCC) { + UCCVerifierWithHyUCCParams const& p(GetParam()); + + auto hyucc = algos::CreateAndLoadAlgorithm(p.GetHyUCCParamsMap()); + hyucc->Execute(); + std::list const& mined_uccs = hyucc->UCCList(); + + // run ucc_verifier on each UCC from mined_ucc (UCC must hold) + for (auto const& current_ucc : mined_uccs) { + algos::StdParamsMap ucc_verifier_params_map = p.GetUCCVerifierParamsMap(); + ucc_verifier_params_map.insert({onam::kUCCIndices, current_ucc.GetColumnIndicesAsVector()}); + auto verifier = algos::CreateAndLoadAlgorithm( + std::move(ucc_verifier_params_map)); + verifier->Execute(); + EXPECT_TRUE(verifier->UCCHolds()); + } + + // Сases of prevent false negative triggering + // run ucc_verifier on each UCC from mined_ucc with one column index removed + // (UCC must not hold because HyUCC returns minimal UCCs) + for (auto const& current_ucc : mined_uccs) { + algos::StdParamsMap ucc_verifier_params_map = p.GetUCCVerifierParamsMap(); + std::vector current_ucc_vec = current_ucc.GetColumnIndicesAsVector(); + if (current_ucc_vec.size() < 2) { + continue; + } + current_ucc_vec.erase(--current_ucc_vec.end()); + ucc_verifier_params_map.insert({onam::kUCCIndices, std::move(current_ucc_vec)}); + auto verifier = algos::CreateAndLoadAlgorithm( + std::move(ucc_verifier_params_map)); + verifier->Execute(); + EXPECT_FALSE(verifier->UCCHolds()); + } +} + +INSTANTIATE_TEST_SUITE_P(UCCVerifierWithHyUCCTestSuite, TestUCCVerifierWithHyUCC, + ::testing::Values(UCCVerifierWithHyUCCParams("abalone.csv", ',', false), + UCCVerifierWithHyUCCParams("breast_cancer.csv"), + UCCVerifierWithHyUCCParams("CIPublicHighway10k.csv"))); } // namespace tests