diff --git a/velox/exec/fuzzer/AggregationFuzzer.cpp b/velox/exec/fuzzer/AggregationFuzzer.cpp index 2b79320698be..8f9138024522 100644 --- a/velox/exec/fuzzer/AggregationFuzzer.cpp +++ b/velox/exec/fuzzer/AggregationFuzzer.cpp @@ -1052,27 +1052,6 @@ void AggregationFuzzer::Stats::print(size_t numIterations) const { AggregationFuzzerBase::Stats::print(numIterations); } -namespace { -// Merges a vector of RowVectors into one RowVector. -RowVectorPtr mergeRowVectors( - const std::vector& results, - velox::memory::MemoryPool* pool) { - auto totalCount = 0; - for (const auto& result : results) { - totalCount += result->size(); - } - auto copy = - BaseVector::create(results[0]->type(), totalCount, pool); - auto copyCount = 0; - for (const auto& result : results) { - copy->copy(result.get(), copyCount, 0, result->size()); - copyCount += result->size(); - } - return copy; -} - -} // namespace - bool AggregationFuzzer::compareEquivalentPlanResults( const std::vector& plans, bool customVerification, @@ -1123,8 +1102,8 @@ bool AggregationFuzzer::compareEquivalentPlanResults( if (referenceResult.first) { velox::fuzzer::ResultOrError expected; - expected.result = - mergeRowVectors(referenceResult.first.value(), pool_.get()); + expected.result = fuzzer::mergeRowVectors( + referenceResult.first.value(), pool_.get()); compare( resultOrError, customVerification, {customVerifier}, expected); diff --git a/velox/exec/fuzzer/WindowFuzzer.cpp b/velox/exec/fuzzer/WindowFuzzer.cpp index 6a3b8ff587c5..73f520af2487 100644 --- a/velox/exec/fuzzer/WindowFuzzer.cpp +++ b/velox/exec/fuzzer/WindowFuzzer.cpp @@ -677,6 +677,34 @@ void initializeVerifier( frame, "w0"); } + +template +T getReferenceResult( + const core::PlanNodePtr& plan, + core::PlanNodeId windowNodeId, + const std::string& prestoFrameClause, + ReferenceQueryRunner* referenceQueryRunner) { + auto prestoQueryRunner = + dynamic_cast(referenceQueryRunner); + bool isPrestoQueryRunner = (prestoQueryRunner != nullptr); + if (isPrestoQueryRunner) { + prestoQueryRunner->queryRunnerContext() + ->windowFrames_[windowNodeId] + .push_back(prestoFrameClause); + } + + T referenceResult; + if constexpr (resultAsVector) { + referenceResult = + computeReferenceResultsAsVector(plan, referenceQueryRunner); + } else { + referenceResult = computeReferenceResults(plan, referenceQueryRunner); + } + if (isPrestoQueryRunner) { + prestoQueryRunner->queryRunnerContext()->windowFrames_.clear(); + } + return referenceResult; +} } // namespace bool WindowFuzzer::verifyWindow( @@ -689,12 +717,6 @@ bool WindowFuzzer::verifyWindow( const std::shared_ptr& customVerifier, bool enableWindowVerification, const std::string& prestoFrameClause) { - SCOPE_EXIT { - if (customVerifier) { - customVerifier->reset(); - } - }; - core::PlanNodeId windowNodeId; auto frame = getFrame(partitionKeys, sortingKeysAndOrders, frameClause); auto plan = PlanBuilder() @@ -707,6 +729,28 @@ bool WindowFuzzer::verifyWindow( persistReproInfo({{plan, {}}}, reproPersistPath_); } + bool customVerifierInitialized = false; + if (customVerifier) { + try { + initializeVerifier( + plan, + customVerifier, + input, + partitionKeys, + sortingKeysAndOrders, + frame); + customVerifierInitialized = true; + } catch (...) { + LOG(WARNING) << "Custom verifier initialization failed"; + } + } + + SCOPE_EXIT { + if (customVerifier) { + customVerifier->reset(); + } + }; + velox::fuzzer::ResultOrError resultOrError; try { resultOrError = execute(plan); @@ -714,51 +758,68 @@ bool WindowFuzzer::verifyWindow( ++stats_.numFailed; } - if (!customVerification) { - if (resultOrError.result && enableWindowVerification) { - auto prestoQueryRunner = - dynamic_cast(referenceQueryRunner_.get()); - bool isPrestoQueryRunner = (prestoQueryRunner != nullptr); - if (isPrestoQueryRunner) { - prestoQueryRunner->queryRunnerContext() - ->windowFrames_[windowNodeId] - .push_back(prestoFrameClause); - } - auto referenceResult = - computeReferenceResults(plan, referenceQueryRunner_.get()); - if (isPrestoQueryRunner) { - prestoQueryRunner->queryRunnerContext()->windowFrames_.clear(); + if (resultOrError.result) { + if (!customVerification) { + if (enableWindowVerification) { + auto referenceResult = getReferenceResult< + std::pair< + std::optional, + ReferenceQueryErrorCode>, + false>( + plan, + windowNodeId, + prestoFrameClause, + referenceQueryRunner_.get()); + stats_.updateReferenceQueryStats(referenceResult.second); + if (auto expectedResult = referenceResult.first) { + ++stats_.numVerified; + stats_.verifiedFunctionNames.insert( + retrieveWindowFunctionName(plan)[0]); + VELOX_CHECK( + assertEqualResults( + expectedResult.value(), + plan->outputType(), + {resultOrError.result}), + "Velox and reference DB results don't match"); + LOG(INFO) << "Verified results against reference DB"; + } } - stats_.updateReferenceQueryStats(referenceResult.second); - if (auto expectedResult = referenceResult.first) { - ++stats_.numVerified; - stats_.verifiedFunctionNames.insert( - retrieveWindowFunctionName(plan)[0]); - VELOX_CHECK( - assertEqualResults( - expectedResult.value(), - plan->outputType(), - {resultOrError.result}), - "Velox and reference DB results don't match"); - LOG(INFO) << "Verified results against reference DB"; + } else if (referenceQueryRunner_->supportsVeloxVectorResults()) { + if (enableWindowVerification) { + auto referenceResult = getReferenceResult< + std::pair< + std::optional>, + ReferenceQueryErrorCode>, + true>( + plan, + windowNodeId, + prestoFrameClause, + referenceQueryRunner_.get()); + stats_.updateReferenceQueryStats(referenceResult.second); + if (auto expectedResult = referenceResult.first) { + ++stats_.numVerified; + stats_.verifiedFunctionNames.insert( + retrieveWindowFunctionName(plan)[0]); + velox::fuzzer::ResultOrError expected; + expected.result = fuzzer::mergeRowVectors( + referenceResult.first.value(), pool_.get()); + + if (customVerifier) { + VELOX_CHECK(customVerifierInitialized); + } + compare( + resultOrError, customVerification, {customVerifier}, expected); + LOG(INFO) << "Verified results against reference DB"; + } } - } - } else { - LOG(INFO) << "Verification through custom verifier"; - ++stats_.numVerificationSkipped; - - if (customVerifier && resultOrError.result) { - VELOX_CHECK( - customVerifier->supportsVerify(), - "Window fuzzer only uses custom verify() methods."); - initializeVerifier( - plan, - customVerifier, - input, - partitionKeys, - sortingKeysAndOrders, - frame); + } else if (customVerifier && customVerifier->supportsVerify()) { + LOG(INFO) << "Verification through custom verifier"; + ++stats_.numVerificationSkipped; + + VELOX_CHECK(customVerifierInitialized); customVerifier->verify(resultOrError.result); + } else { + LOG(WARNING) << "No Verification Performed"; } } diff --git a/velox/expression/fuzzer/FuzzerToolkit.cpp b/velox/expression/fuzzer/FuzzerToolkit.cpp index e57319c87954..421c96bf6951 100644 --- a/velox/expression/fuzzer/FuzzerToolkit.cpp +++ b/velox/expression/fuzzer/FuzzerToolkit.cpp @@ -157,6 +157,23 @@ void compareVectors( LOG(INFO) << "Two vectors match."; } +RowVectorPtr mergeRowVectors( + const std::vector& results, + velox::memory::MemoryPool* pool) { + auto totalCount = 0; + for (const auto& result : results) { + totalCount += result->size(); + } + auto copy = + BaseVector::create(results[0]->type(), totalCount, pool); + auto copyCount = 0; + for (const auto& result : results) { + copy->copy(result.get(), copyCount, 0, result->size()); + copyCount += result->size(); + } + return copy; +} + void InputRowMetadata::saveToFile(const char* filePath) const { std::ofstream outputFile(filePath, std::ofstream::binary); saveStdVector(columnsToWrapInLazy, outputFile); diff --git a/velox/expression/fuzzer/FuzzerToolkit.h b/velox/expression/fuzzer/FuzzerToolkit.h index e23f9783e9f5..a6c326250165 100644 --- a/velox/expression/fuzzer/FuzzerToolkit.h +++ b/velox/expression/fuzzer/FuzzerToolkit.h @@ -120,6 +120,11 @@ void compareVectors( const std::string& rightName = "right", const std::optional& rows = std::nullopt); +// Merges a vector of RowVectors into one RowVector. +RowVectorPtr mergeRowVectors( + const std::vector& results, + velox::memory::MemoryPool* pool); + struct InputRowMetadata { // Column indices to wrap in LazyVector (in a strictly increasing order) std::vector columnsToWrapInLazy; diff --git a/velox/functions/prestosql/fuzzer/AverageResultVerifier.h b/velox/functions/prestosql/fuzzer/AverageResultVerifier.h index 230046d21cd0..ec7f22f7f917 100644 --- a/velox/functions/prestosql/fuzzer/AverageResultVerifier.h +++ b/velox/functions/prestosql/fuzzer/AverageResultVerifier.h @@ -52,6 +52,20 @@ class AverageResultVerifier : public ResultVerifier { } } + void initializeWindow( + const std::vector& input, + const std::vector& /*partitionByKeys*/, + const std::vector& /*sortingKeysAndOrders*/, + const core::WindowNode::Function& function, + const std::string& /*frame*/, + const std::string& windowName) override { + if (function.functionCall->type()->isIntervalDayTime()) { + projections_ = asRowType(input[0]->type())->names(); + projections_.push_back( + fmt::format("cast(to_milliseconds({}) as double)", windowName)); + } + } + bool compare(const RowVectorPtr& result, const RowVectorPtr& altResult) override { if (projections_.empty()) { diff --git a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp index d67008d2f005..6453df359af8 100644 --- a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp +++ b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp @@ -25,6 +25,7 @@ #include "velox/functions/prestosql/fuzzer/ApproxDistinctResultVerifier.h" #include "velox/functions/prestosql/fuzzer/ApproxPercentileInputGenerator.h" #include "velox/functions/prestosql/fuzzer/ApproxPercentileResultVerifier.h" +#include "velox/functions/prestosql/fuzzer/AverageResultVerifier.h" #include "velox/functions/prestosql/fuzzer/ClassificationAggregationInputGenerator.h" #include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h" #include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h" @@ -130,6 +131,7 @@ int main(int argc, char** argv) { // TODO: allow custom result verifiers. using facebook::velox::exec::test::ApproxDistinctResultVerifier; using facebook::velox::exec::test::ApproxPercentileResultVerifier; + using facebook::velox::exec::test::AverageResultVerifier; static const std::unordered_map< std::string, @@ -149,6 +151,7 @@ int main(int argc, char** argv) { // https://github.com/facebookincubator/velox/issues/6330 {"max_data_size_for_stats", nullptr}, {"sum_data_size_for_stats", nullptr}, + {"avg", std::make_shared()}, }; static const std::unordered_set orderDependentFunctions = {