diff --git a/src/core/algorithms/algorithm_types.h b/src/core/algorithms/algorithm_types.h index 45836120cb..25f078b92c 100644 --- a/src/core/algorithms/algorithm_types.h +++ b/src/core/algorithms/algorithm_types.h @@ -10,8 +10,8 @@ using AlgorithmTypes = std::tuple; + Fastod, order::Order, od_verifier::ODVerifier, GfdValidation, EGfdValidation, + NaiveGfdValidation, dd::Split>; // clang-format off /* Enumeration of all supported non-pipeline algorithms. If you implement a new @@ -66,15 +66,16 @@ BETTER_ENUM(AlgorithmType, char, /* Order dependency mining algorithms */ fastod, + order, + +/* Canonical OD verifier algorithm */ + od_verifier, /* Graph functional dependency mining algorithms */ gfdvalid, egfdvalid, naivegfdvalid, -/* Order dependency mining algorithms */ - order, - /* Differential dependencies mining algorithm */ split ) diff --git a/src/core/algorithms/algorithms.h b/src/core/algorithms/algorithms.h index 0076f7706e..feed9a9f3e 100644 --- a/src/core/algorithms/algorithms.h +++ b/src/core/algorithms/algorithms.h @@ -10,6 +10,7 @@ #include "algorithms/ind/mining_algorithms.h" #include "algorithms/metric/verification_algorithms.h" #include "algorithms/od/mining_algorithms.h" +#include "algorithms/od/verification_algorithms.h" #include "algorithms/statistics/algorithms.h" #include "algorithms/ucc/mining_algorithms.h" #include "algorithms/ucc/verification_algorithms.h" diff --git a/src/core/algorithms/od/fastod/partitions/complex_stripped_partition.h b/src/core/algorithms/od/fastod/partitions/complex_stripped_partition.h index 881d5a5b94..17e0c3988c 100644 --- a/src/core/algorithms/od/fastod/partitions/complex_stripped_partition.h +++ b/src/core/algorithms/od/fastod/partitions/complex_stripped_partition.h @@ -9,7 +9,7 @@ namespace algos::fastod { class ComplexStrippedPartition { -private: +protected: std::shared_ptr> sp_indexes_; std::shared_ptr> sp_begins_; std::shared_ptr> rb_indexes_; diff --git a/src/core/algorithms/od/od_verifier/od_verifier.cpp b/src/core/algorithms/od/od_verifier/od_verifier.cpp new file mode 100644 index 0000000000..c3fcb7b4a6 --- /dev/null +++ b/src/core/algorithms/od/od_verifier/od_verifier.cpp @@ -0,0 +1,75 @@ +#include "od_verifier.h" + +#include "ascending_od/option.h" +#include "config/equal_nulls/option.h" +#include "config/indices/od_context.h" +#include "config/indices/option.h" +#include "config/tabular_data/input_table/option.h" + +namespace algos::od_verifier { + +ODVerifier::ODVerifier() : Algorithm({}) { + RegisterOptions(); + MakeOptionsAvailable({config::kTableOpt.GetName(), config::kEqualNullsOpt.GetName()}); +} + +void ODVerifier::RegisterOptions() { + auto get_schema_cols = [this]() { return relation_->GetSchema()->GetNumColumns(); }; + + IndicesType lhs_indices_, rhs_indices_; + RegisterOption(config::kTableOpt(&input_table_)); + RegisterOption(config::kEqualNullsOpt(&is_null_equal_null_)); + RegisterOption(config::kLhsIndicesOpt(&lhs_indices_, get_schema_cols)); + RegisterOption(config::kRhsIndicesOpt(&rhs_indices_, get_schema_cols)); + RegisterOption(config::kODContextOpt(&context_indices_)); + RegisterOption(config::kAscendingODOpt(&ascending_)); + lhs_indicex_ = lhs_indices_[0]; + rhs_indicex_ = rhs_indices_[0]; +} + +void ODVerifier::MakeExecuteOptsAvailable() { + MakeOptionsAvailable({config::kLhsIndicesOpt.GetName(), config::kRhsIndicesOpt.GetName(), + config::kODContextOpt.GetName(), config::kAscendingODOpt.GetName()}); +} + +void ODVerifier::LoadDataInternal() { + relation_ = ColumnLayoutRelationData::CreateFrom(*input_table_, is_null_equal_null_); + + if (relation_->GetColumnData().empty()) { + throw std::runtime_error("Got an empty dataset: OD verifying is meaningless."); + } + input_table_->Reset(); + data_ = std::make_shared(DataFrame::FromInputTable(input_table_)); + if (data_->GetColumnCount() == 0) { + throw std::runtime_error("Got an empty dataset: OD verifying is meaningless."); + } +} + +unsigned long long ODVerifier::ExecuteInternal() { + auto start_time = std::chrono::system_clock::now(); + if (ascending_) { + VerifyOD(); + } else { + VerifyOD(); + } + auto elapsed_milliseconds = std::chrono::duration_cast( + std::chrono::system_clock::now() - start_time); + return elapsed_milliseconds.count(); +} + +// checks whether the OD has broken +bool ODVerifier::ODHolds() const { + return row_violate_ods_by_swap_.empty() && row_violate_ods_by_split_.empty(); +} + +// Returns the number of rows that violate the OD by split +size_t ODVerifier::GetNumRowsViolateBySplit() const { + return row_violate_ods_by_split_.size(); +} + +// Returns the number of rows that violate the OD by swap +size_t ODVerifier::GetNumRowsViolateBySwap() const { + return row_violate_ods_by_swap_.size(); +} + +} // namespace algos::od_verifier diff --git a/src/core/algorithms/od/od_verifier/od_verifier.h b/src/core/algorithms/od/od_verifier/od_verifier.h new file mode 100644 index 0000000000..4f2008629e --- /dev/null +++ b/src/core/algorithms/od/od_verifier/od_verifier.h @@ -0,0 +1,101 @@ +#pragma once + +#include "algorithms/algorithm.h" +#include "algorithms/od/fastod/model/canonical_od.h" +#include "config/indices/type.h" +#include "model/table/column_layout_relation_data.h" +#include "partition.h" + +namespace algos::od_verifier { + +class ODVerifier : public Algorithm { +private: + using IndicesType = config::IndicesType; + using IndexType = config::IndexType; + using DataFrame = fastod::DataFrame; + using PartitionCache = fastod::PartitionCache; + using AscCanonicalOD = fastod::AscCanonicalOD; + using DescCanonicalOD = fastod::DescCanonicalOD; + using SimpleCanonicalOD = fastod::SimpleCanonicalOD; + using AttributeSet = fastod::AttributeSet; + + // input data + config::InputTable input_table_; + config::EqNullsType is_null_equal_null_; + IndexType lhs_indicex_; + IndexType rhs_indicex_; + IndicesType context_indices_; + bool ascending_; + + // auxiliary data + std::shared_ptr relation_; + std::shared_ptr data_; + PartitionCache partition_cache_; + + // rows that vioalates ods + std::vector row_violate_ods_by_swap_; + std::vector row_violate_ods_by_split_; + + // load input data + void RegisterOptions(); + void MakeExecuteOptsAvailable() override; + void LoadDataInternal() override; + + // runs the algorithm and measures its time + unsigned long long ExecuteInternal() override; + + // checks whether OD is violated and finds the rows where it is violated + template + void VerifyOD() { + AttributeSet context; + + for (auto column : context_indices_) context.Set(column); + + fastod::ComplexStrippedPartition stripped_partition_swap( + (partition_cache_.GetStrippedPartition(context, data_))); + + if (stripped_partition_swap.Swap(lhs_indicex_, rhs_indicex_)) { + ComplaxStrippedPartition part{stripped_partition_swap}; + std::vector> violates( + part.FindViolationsBySwap(lhs_indicex_, rhs_indicex_)); + + for (auto position_violate : violates) + row_violate_ods_by_swap_.push_back(position_violate.second + 1); + } + + context.Set(lhs_indicex_); + fastod::ComplexStrippedPartition stripped_partition_split( + partition_cache_.GetStrippedPartition(context, data_)); + + if (stripped_partition_split.Split(rhs_indicex_)) { + ComplaxStrippedPartition part{stripped_partition_split}; + std::vector> violates(part.FindViolationsBySplit(rhs_indicex_)); + + for (auto position_violate : violates) + row_violate_ods_by_split_.push_back(position_violate.second + 1); + } + std::sort(row_violate_ods_by_split_.begin(), row_violate_ods_by_split_.end()); + std::sort(row_violate_ods_by_swap_.begin(), row_violate_ods_by_swap_.end()); + } + + // reset statistic of violations + void ResetState() override { + row_violate_ods_by_swap_.clear(); + row_violate_ods_by_split_.clear(); + } + +public: + // base constructor + ODVerifier(); + + // checks whether the OD has broken + bool ODHolds() const; + + // Returns the number of rows that violate the OD by split + size_t GetNumRowsViolateBySplit() const; + + // Returns the number of rows that violate the OD by swap + size_t GetNumRowsViolateBySwap() const; +}; + +} // namespace algos::od_verifier diff --git a/src/core/algorithms/od/od_verifier/partition.cpp b/src/core/algorithms/od/od_verifier/partition.cpp new file mode 100644 index 0000000000..1e1ae9d933 --- /dev/null +++ b/src/core/algorithms/od/od_verifier/partition.cpp @@ -0,0 +1,58 @@ +#include "partition.h" + +#include +#include + +namespace algos::od_verifier { + +std::vector +ComplaxStrippedPartition::CommonViolationBySplit(model::ColumnIndex right) const { + std::vector violates; + + for (size_t begin_pointer = 0; begin_pointer < sp_begins_->size() - 1; begin_pointer++) { + size_t const group_begin = (*sp_begins_)[begin_pointer]; + size_t const group_end = (*sp_begins_)[begin_pointer + 1]; + + int const group_value = data_->GetValue((*sp_indexes_)[group_begin], right); + + for (size_t i = group_begin + 1; i < group_end; i++) { + if (data_->GetValue((*sp_indexes_)[i], right) != group_value) { + violates.emplace_back(right, (*sp_indexes_)[i]); + } + } + } + + return violates; +} + +std::vector +ComplaxStrippedPartition::RangeBasedViolationBySplit(model::ColumnIndex right) const { + std::vector violates; + + for (size_t begin_pointer = 0; begin_pointer < rb_begins_->size() - 1; ++begin_pointer) { + size_t const group_begin = (*rb_begins_)[begin_pointer]; + size_t const group_end = (*rb_begins_)[begin_pointer + 1]; + + int const group_value = data_->GetValue((*rb_indexes_)[group_begin].first, right); + + for (size_t i = group_begin; i < group_end; ++i) { + algos::fastod::DataFrame::Range const range = (*rb_indexes_)[i]; + + for (size_t j = range.first; j <= range.second; ++j) { + if (data_->GetValue(j, right) != group_value) { + violates.emplace_back(right, j); + } + } + } + } + + return violates; +} + +std::vector +ComplaxStrippedPartition::FindViolationsBySplit(model::ColumnIndex right) const { + return is_stripped_partition_ ? CommonViolationBySplit(right) + : RangeBasedViolationBySplit(right); +} + +} // namespace algos::od_verifier diff --git a/src/core/algorithms/od/od_verifier/partition.h b/src/core/algorithms/od/od_verifier/partition.h new file mode 100644 index 0000000000..64244d231d --- /dev/null +++ b/src/core/algorithms/od/od_verifier/partition.h @@ -0,0 +1,92 @@ +#pragma once + +#include "algorithms/od/fastod/partitions/complex_stripped_partition.h" + +namespace algos::od_verifier { + +class ComplaxStrippedPartition : protected algos::fastod::ComplexStrippedPartition { +private: + using ViolationDescription = std::pair; + + std::vector CommonViolationBySplit(model::ColumnIndex right) const; + + std::vector RangeBasedViolationBySplit(model::ColumnIndex right) const; + +public: + ComplaxStrippedPartition() : algos::fastod::ComplexStrippedPartition() {} + + ComplaxStrippedPartition(algos::fastod::ComplexStrippedPartition const& daddy) + : algos::fastod::ComplexStrippedPartition(daddy) {} + + std::vector FindViolationsBySplit(model::ColumnIndex right) const; + + template + std::vector FindViolationsBySwap(model::ColumnIndex left, + model::ColumnIndex right) const { + size_t const group_count = is_stripped_partition_ ? sp_begins_->size() : rb_begins_->size(); + std::vector violates; + + for (size_t begin_pointer = 0; begin_pointer < group_count - 1; begin_pointer++) { + size_t const group_begin = is_stripped_partition_ ? (*sp_begins_)[begin_pointer] + : (*rb_begins_)[begin_pointer]; + + size_t const group_end = is_stripped_partition_ ? (*sp_begins_)[begin_pointer + 1] + : (*rb_begins_)[begin_pointer + 1]; + + std::vector> values; + std::vector row_pos; + + if (is_stripped_partition_) { + values.reserve(group_end - group_begin); + + for (size_t i = group_begin; i < group_end; ++i) { + size_t const index = (*sp_indexes_)[i]; + + values.emplace_back(data_->GetValue(index, left), + data_->GetValue(index, right)); + row_pos.emplace_back(index); + } + } else { + for (size_t i = group_begin; i < group_end; ++i) { + algos::fastod::DataFrame::Range const range = (*rb_indexes_)[i]; + + for (size_t j = range.first; j <= range.second; ++j) { + values.emplace_back(data_->GetValue(j, left), data_->GetValue(j, right)); + } + } + } + + if constexpr (Ascending) { + std::sort(values.begin(), values.end(), + [](auto const& p1, auto const& p2) { return p1.first < p2.first; }); + } else { + std::sort(values.begin(), values.end(), + [](auto const& p1, auto const& p2) { return p2.first < p1.first; }); + } + + size_t prev_group_max_index = 0; + size_t current_group_max_index = 0; + bool is_first_group = true; + + for (size_t i = 0; i < values.size(); i++) { + auto const& [first, second] = values[i]; + + if (i != 0 && values[i - 1].first != first) { + is_first_group = false; + prev_group_max_index = current_group_max_index; + current_group_max_index = i; + } else if (values[current_group_max_index].second <= second) { + current_group_max_index = i; + } + + if (!is_first_group && values[prev_group_max_index].second > second) { + violates.emplace_back(right, row_pos[i]); + } + } + } + + return violates; + } +}; + +} // namespace algos::od_verifier diff --git a/src/core/algorithms/od/verification_algorithms.h b/src/core/algorithms/od/verification_algorithms.h new file mode 100644 index 0000000000..d3028d84ab --- /dev/null +++ b/src/core/algorithms/od/verification_algorithms.h @@ -0,0 +1,3 @@ +#pragma once + +#include "algorithms/od/od_verifier/od_verifier.h" diff --git a/src/core/config/ascending_od/option.cpp b/src/core/config/ascending_od/option.cpp new file mode 100644 index 0000000000..a6074dade8 --- /dev/null +++ b/src/core/config/ascending_od/option.cpp @@ -0,0 +1,9 @@ +#include "ascending_od/option.h" + +#include "ascending_od/type.h" +#include "config/names_and_descriptions.h" + +namespace config { +extern CommonOption const kAscendingODOpt{names::kAscendingOD, + descriptions::kDAscendingOD, true}; +} // namespace config diff --git a/src/core/config/ascending_od/option.h b/src/core/config/ascending_od/option.h new file mode 100644 index 0000000000..5ee7737a63 --- /dev/null +++ b/src/core/config/ascending_od/option.h @@ -0,0 +1,9 @@ +#pragma once + +#include "config/ascending_od/type.h" +#include "config/common_option.h" + +namespace config { +extern CommonOption const kAscendingODOpt; + +} // namespace config diff --git a/src/core/config/ascending_od/type.h b/src/core/config/ascending_od/type.h new file mode 100644 index 0000000000..f5b6ad9150 --- /dev/null +++ b/src/core/config/ascending_od/type.h @@ -0,0 +1,5 @@ +#pragma once + +namespace config { +using AscendingODFlagType = bool; +} // namespace config diff --git a/src/core/config/descriptions.h b/src/core/config/descriptions.h index 4e8b85c3ea..b5b541668d 100644 --- a/src/core/config/descriptions.h +++ b/src/core/config/descriptions.h @@ -44,6 +44,8 @@ constexpr auto kDItemColumnIndex = "index of the column where an item name is st constexpr auto kDFirstColumnTId = "indicates that the first column contains the transaction IDs"; auto const kDMetric = details::kDMetricString.c_str(); constexpr auto kDLhsIndices = "LHS column indices"; +constexpr auto kDODContext = "context columns indices"; +constexpr auto kDAscendingOD = "flag shows whether the dependence is ascending or descending"; constexpr auto kDRhsIndices = "RHS column indices"; constexpr auto kDRhsIndex = "RHS column index"; constexpr auto kDUCCIndices = "column indices for UCC verification"; diff --git a/src/core/config/indices/od_context.cpp b/src/core/config/indices/od_context.cpp new file mode 100644 index 0000000000..3c476ee7fa --- /dev/null +++ b/src/core/config/indices/od_context.cpp @@ -0,0 +1,9 @@ +#include "config/indices/od_context.h" + +#include "config/names_and_descriptions.h" +#include "indices/type.h" + +namespace config { +extern CommonOption const kODContextOpt{names::kODContext, descriptions::kDODContext, + IndicesType({})}; +} // namespace config diff --git a/src/core/config/indices/od_context.h b/src/core/config/indices/od_context.h new file mode 100644 index 0000000000..7930af1ee7 --- /dev/null +++ b/src/core/config/indices/od_context.h @@ -0,0 +1,8 @@ +#pragma once + +#include "config/common_option.h" +#include "config/indices/type.h" + +namespace config { +extern CommonOption const kODContextOpt; +} // namespace config diff --git a/src/core/config/names.h b/src/core/config/names.h index 74f532cc95..5ecd9eadf3 100644 --- a/src/core/config/names.h +++ b/src/core/config/names.h @@ -29,6 +29,8 @@ constexpr auto kRhsIndex = "rhs_index"; constexpr auto kUCCIndices = "ucc_indices"; constexpr auto kParameter = "parameter"; constexpr auto kDistFromNullIsInfinity = "dist_from_null_is_infinity"; +constexpr auto kODContext = "od_context"; +constexpr auto kAscendingOD = "ascending"; constexpr auto kQGramLength = "q"; constexpr auto kMetricAlgorithm = "metric_algorithm"; constexpr auto kRadius = "radius"; diff --git a/src/tests/all_csv_configs.cpp b/src/tests/all_csv_configs.cpp index e1bdf6bcd2..cc42aebb0c 100644 --- a/src/tests/all_csv_configs.cpp +++ b/src/tests/all_csv_configs.cpp @@ -46,6 +46,7 @@ CSVConfig const kTestEmpty = CreateCsvConfig("TestEmpty.csv", ',', true); CSVConfig const kTestSingleColumn = CreateCsvConfig("TestSingleColumn.csv", ',', true); CSVConfig const kTestLong = CreateCsvConfig("TestLong.csv", ',', true); CSVConfig const kTestFD = CreateCsvConfig("TestFD.csv", ',', true); +CSVConfig const kTestODVerifier = CreateCsvConfig("ODVerificationData.csv", ',', true); CSVConfig const kOdTestNormOd = CreateCsvConfig("od_norm_data/OD_norm.csv", ',', true); CSVConfig const kOdTestNormSmall2x3 = CreateCsvConfig("od_norm_data/small_2x3.csv", ',', true); CSVConfig const kOdTestNormSmall3x3 = CreateCsvConfig("od_norm_data/small_3x3.csv", ',', true); diff --git a/src/tests/all_csv_configs.h b/src/tests/all_csv_configs.h index 9564a1f747..a8e874cbdb 100644 --- a/src/tests/all_csv_configs.h +++ b/src/tests/all_csv_configs.h @@ -36,6 +36,7 @@ extern CSVConfig const kTestEmpty; extern CSVConfig const kTestSingleColumn; extern CSVConfig const kTestLong; extern CSVConfig const kTestFD; +extern CSVConfig const kTestODVerifier; extern CSVConfig const kOdTestNormOd; extern CSVConfig const kOdTestNormSmall2x3; extern CSVConfig const kOdTestNormSmall3x3; diff --git a/src/tests/test_ind_util.h b/src/tests/test_ind_util.h index 63209a5006..64f838db2b 100644 --- a/src/tests/test_ind_util.h +++ b/src/tests/test_ind_util.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include diff --git a/src/tests/test_od_verifier.cpp b/src/tests/test_od_verifier.cpp new file mode 100644 index 0000000000..2edfec2cb7 --- /dev/null +++ b/src/tests/test_od_verifier.cpp @@ -0,0 +1,52 @@ +#include + +#include "algo_factory.h" +#include "all_csv_configs.h" +#include "config/names.h" +#include "od/od_verifier/od_verifier.h" + +namespace tests { + +struct ODVerifyingParams { + algos::StdParamsMap params; + size_t const number_of_rows_violate_by_split = 0; + size_t const number_of_rows_violate_by_swap = 0; + + ODVerifyingParams(config::IndicesType lhs_indices, config::IndicesType rhs_indices, + config::IndicesType context, bool ascending, size_t const row_error_split = 0, + size_t const row_error_swap = 0, + CSVConfig const& csv_config = kTestODVerifier) + : params({{config::names::kCsvConfig, csv_config}, + {config::names::kLhsIndices, std::move(lhs_indices)}, + {config::names::kRhsIndices, std::move(rhs_indices)}, + {config::names::kODContext, std::move(context)}, + {config::names::kAscendingOD, ascending}}), + number_of_rows_violate_by_split(row_error_split), + number_of_rows_violate_by_swap(row_error_swap) {} +}; + +class TestODVerifying : public ::testing::TestWithParam {}; + +TEST_P(TestODVerifying, DefaultTest) { + auto const& p = GetParam(); + auto mp = algos::StdParamsMap(p.params); + auto verifier = algos::CreateAndLoadAlgorithm(mp); + verifier->Execute(); + EXPECT_EQ(verifier->GetNumRowsViolateBySwap(), p.number_of_rows_violate_by_swap); + EXPECT_EQ(verifier->GetNumRowsViolateBySplit(), p.number_of_rows_violate_by_split); +} + +// clang-format off +INSTANTIATE_TEST_SUITE_P( + ODVerifierTestSuite, TestODVerifying, + ::testing::Values( + ODVerifyingParams({1}, {2}, {0}, true, 0, 0), + ODVerifyingParams({1}, {2}, {}, true, 1, 2), + ODVerifyingParams({3}, {4}, {0}, true, 0, 1), + ODVerifyingParams({1}, {2}, {0}, false, 0, 3), + ODVerifyingParams({3}, {4}, {0}, false, 0, 2), + ODVerifyingParams({5}, {6}, {0}, true, 1, 0) + )); +// clang-format on + +} // namespace tests diff --git a/test_input_data/ODVerificationData.csv b/test_input_data/ODVerificationData.csv new file mode 100644 index 0000000000..05eaa71bcf --- /dev/null +++ b/test_input_data/ODVerificationData.csv @@ -0,0 +1,7 @@ +1,2,3,4,5,6,7 +2020,10,1000,10,1000,10,1000 +2020,20,2000,20,2000,10,1000 +2020,30,3000,30,10,10,1001 +2021,10,1000,40,1000,10,1002 +2021,20,1500,50,2000,10,1002 +2022,5,10000,60,1000,10,1003