From d6e18257a3442f209734e5e6ec5245d5a27a4246 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 23:19:39 -0500 Subject: [PATCH] Improve filter tests --- optd-cost-model/src/cost/filter/controller.rs | 143 ++++++++---------- optd-cost-model/src/cost/filter/in_list.rs | 14 +- optd-cost-model/src/cost/filter/like.rs | 13 +- optd-cost-model/src/stats/mod.rs | 12 +- .../src/stats/utilities/simple_map.rs | 9 +- 5 files changed, 93 insertions(+), 98 deletions(-) diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 39462ae..c10ea1d 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -136,11 +136,11 @@ mod tests { #[tokio::test] async fn test_attr_ref_eq_constint_in_mcv() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -170,12 +170,11 @@ mod tests { #[tokio::test] async fn test_attr_ref_eq_constint_not_in_mcv() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 20); - mcvs_counts.insert(vec![Some(Value::Int32(3))], 44); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.2), + (vec![Some(Value::Int32(3))], 0.44), + ])), 5, 0.0, None, @@ -206,11 +205,11 @@ mod tests { /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ #[tokio::test] async fn test_attr_ref_neq_constint_in_mcv() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -240,10 +239,8 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_no_mcvs_in_range() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -280,14 +277,13 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_with_mcvs_in_range_not_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(17))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -324,14 +320,13 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_with_mcv_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(15))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -368,10 +363,8 @@ mod tests { #[tokio::test] async fn test_attr_ref_lt_constint_no_mcvs_in_range() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -408,14 +401,13 @@ mod tests { #[tokio::test] async fn test_attr_ef_lt_constint_with_mcvs_in_range_not_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(17))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each * remaining value has freq 0.1 */ 0.0, @@ -453,14 +445,13 @@ mod tests { #[tokio::test] async fn test_attr_ref_lt_constint_with_mcv_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(15))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each * remaining value has freq 0.1 */ 0.0, @@ -500,10 +491,8 @@ mod tests { /// The only interesting thing to test is that if there are nulls, those aren't included in GT #[tokio::test] async fn test_attr_ref_gt_constint() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -540,10 +529,8 @@ mod tests { #[tokio::test] async fn test_attr_ref_geq_constint() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -581,13 +568,12 @@ mod tests { #[tokio::test] async fn test_and() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - mcvs_counts.insert(vec![Some(Value::Int32(5))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(8))], 2); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), 0, 0.0, None, @@ -629,13 +615,12 @@ mod tests { #[tokio::test] async fn test_or() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - mcvs_counts.insert(vec![Some(Value::Int32(5))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(8))], 2); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), 0, 0.0, None, @@ -677,11 +662,11 @@ mod tests { #[tokio::test] async fn test_not() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -710,11 +695,11 @@ mod tests { #[tokio::test] async fn test_attr_ref_eq_cast_value() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -753,11 +738,11 @@ mod tests { #[tokio::test] async fn test_cast_attr_ref_eq_value() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.1, None, @@ -812,10 +797,8 @@ mod tests { /// pretty good signal that the Cast was left as is. #[tokio::test] async fn test_cast_attr_ref_eq_attr_ref() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 0, 0.0, None, diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 16080de..2363d4a 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -73,17 +73,19 @@ mod tests { use crate::{ common::{types::TableId, values::Value}, cost_model::tests::*, - stats::{utilities::counter::Counter, MostCommonValues}, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, + }, }; #[tokio::test] async fn test_in_list() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(2))], 2); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.8), + (vec![Some(Value::Int32(2))], 0.2), + ])), 2, 0.0, None, diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index f49ca18..03da4d1 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -108,19 +108,18 @@ mod tests { common::{types::TableId, values::Value}, cost_model::tests::*, stats::{ - utilities::counter::Counter, MostCommonValues, FIXED_CHAR_SEL_FACTOR, - FULL_WILDCARD_SEL_FACTOR, + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, }, }; #[tokio::test] async fn test_like_no_nulls() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::String("abcd".into()))], 1); - mcvs_counts.insert(vec![Some(Value::String("abc".into()))], 1); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::String("abcd".into()))], 0.1), + (vec![Some(Value::String("abc".into()))], 0.1), + ])), 2, 0.0, None, diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 5440ea1..a20cdf2 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -35,6 +35,7 @@ pub type AttributeCombValue = Vec>; #[serde(tag = "type")] pub enum MostCommonValues { Counter(Counter), + SimpleFrequency(SimpleMap), // Add more types here... } @@ -47,12 +48,14 @@ impl MostCommonValues { pub fn freq(&self, value: &AttributeCombValue) -> Option { match self { MostCommonValues::Counter(counter) => counter.frequencies().get(value).copied(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.get(value).copied(), } } pub fn total_freq(&self) -> f64 { match self { MostCommonValues::Counter(counter) => counter.frequencies().values().sum(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.values().sum(), } } @@ -64,6 +67,12 @@ impl MostCommonValues { .filter(|(val, _)| pred(val)) .map(|(_, freq)| freq) .sum(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map + .m + .iter() + .filter(|(val, _)| pred(val)) + .map(|(_, freq)| freq) + .sum(), } } @@ -71,6 +80,7 @@ impl MostCommonValues { pub fn cnt(&self) -> usize { match self { MostCommonValues::Counter(counter) => counter.frequencies().len(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.len(), } } } @@ -80,7 +90,7 @@ impl MostCommonValues { #[serde(tag = "type")] pub enum Distribution { TDigest(TDigest), - SimpleDistribution(SimpleMap), + SimpleDistribution(SimpleMap), // Add more types here... } diff --git a/optd-cost-model/src/stats/utilities/simple_map.rs b/optd-cost-model/src/stats/utilities/simple_map.rs index 5503b2f..f685fe6 100644 --- a/optd-cost-model/src/stats/utilities/simple_map.rs +++ b/optd-cost-model/src/stats/utilities/simple_map.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::hash::Hash; use serde::{Deserialize, Serialize}; @@ -7,12 +8,12 @@ use crate::common::values::Value; /// TODO: documentation /// Now it is mainly for testing purposes. #[derive(Clone, Serialize, Deserialize, Debug)] -pub struct SimpleMap { - pub(crate) m: HashMap, +pub struct SimpleMap { + pub(crate) m: HashMap, } -impl SimpleMap { - pub fn new(v: Vec<(Value, f64)>) -> Self { +impl SimpleMap { + pub fn new(v: Vec<(K, f64)>) -> Self { Self { m: v.into_iter().collect(), }