From 814c3d6037a957418e2d28748257cc064c66d081 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 12:50:13 -0500 Subject: [PATCH] Resolve conflict with main --- Cargo.lock | 22 ++++ optd-cost-model/src/cost/agg.rs | 20 ++-- optd-cost-model/src/cost/filter/attribute.rs | 50 ++++---- optd-cost-model/src/cost/filter/comp_op.rs | 42 ++++--- optd-cost-model/src/cost/filter/controller.rs | 113 +++++++++--------- optd-cost-model/src/cost/filter/in_list.rs | 19 +-- optd-cost-model/src/cost/filter/like.rs | 7 +- optd-cost-model/src/cost/filter/log_op.rs | 25 ++-- optd-cost-model/src/cost_model.rs | 3 +- optd-cost-model/src/lib.rs | 9 +- optd-cost-model/src/storage.rs | 43 +++++-- optd-persistent/Cargo.toml | 1 + optd-persistent/src/cost_model/interface.rs | 7 +- optd-persistent/src/cost_model/orm.rs | 32 +++-- 14 files changed, 249 insertions(+), 144 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf0b367..d47ecc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2177,6 +2177,27 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "object" version = "0.36.5" @@ -2237,6 +2258,7 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", + "num_enum", "sea-orm", "sea-orm-migration", "serde_json", diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index a3de5aa..ff82fec 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -8,11 +8,11 @@ use crate::{ }, cost_model::CostModelImpl, stats::DEFAULT_NUM_DISTINCT, - CostModelError, CostModelResult, EstimatedStatistic, + CostModelError, CostModelResult, EstimatedStatistic, SemanticError, }; impl CostModelImpl { - pub fn get_agg_row_cnt( + pub async fn get_agg_row_cnt( &self, group_by: ArcPredicateNode, ) -> CostModelResult { @@ -22,22 +22,24 @@ impl CostModelImpl { } else { // Multiply the n-distinct of all the group by columns. // TODO: improve with multi-dimensional n-distinct - let row_cnt = group_by.0.children.iter().try_fold(1, |acc, node| { + let mut row_cnt = 1; + + for node in &group_by.0.children { match node.typ { PredicateType::AttributeRef => { let attr_ref = AttributeRefPred::from_pred_node(node.clone()).ok_or_else(|| { - CostModelError::InvalidPredicate( + SemanticError::InvalidPredicate( "Expected AttributeRef predicate".to_string(), ) })?; if attr_ref.is_derived() { - Ok(acc * DEFAULT_NUM_DISTINCT) + row_cnt *= DEFAULT_NUM_DISTINCT; } else { let table_id = attr_ref.table_id(); let attr_idx = attr_ref.attr_index(); let stats_option = - self.get_attribute_comb_stats(table_id, &[attr_idx])?; + self.get_attribute_comb_stats(table_id, &[attr_idx]).await?; let ndistinct = match stats_option { Some(stats) => stats.ndistinct, @@ -46,15 +48,15 @@ impl CostModelImpl { DEFAULT_NUM_DISTINCT } }; - Ok(acc * ndistinct) + row_cnt *= ndistinct; } } _ => { // TODO: Consider the case where `GROUP BY 1`. - panic!("GROUP BY must have attribute ref predicate") + panic!("GROUP BY must have attribute ref predicate"); } } - })?; + } Ok(EstimatedStatistic(row_cnt)) } } diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs index b72802e..c5ad90c 100644 --- a/optd-cost-model/src/cost/filter/attribute.rs +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -19,7 +19,7 @@ impl CostModelImpl { /// Also, get_attribute_equality_selectivity is a subroutine when computing range /// selectivity, which is another reason for separating these into two functions /// is_eq means whether it's == or != - pub(crate) fn get_attribute_equality_selectivity( + pub(crate) async fn get_attribute_equality_selectivity( &self, table_id: TableId, attr_base_index: usize, @@ -28,8 +28,9 @@ impl CostModelImpl { ) -> CostModelResult { // TODO: The attribute could be a derived attribute let ret_sel = { - if let Some(attribute_stats) = - self.get_attribute_comb_stats(table_id, &[attr_base_index])? + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? { let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { @@ -91,7 +92,7 @@ impl CostModelImpl { } /// Compute the frequency of values in a attribute less than the given value. - fn get_attribute_lt_value_freq( + async fn get_attribute_lt_value_freq( &self, attribute_stats: &AttributeCombValueStats, table_id: TableId, @@ -102,7 +103,9 @@ impl CostModelImpl { // into total_leq_cdf this logic just so happens to be the exact same logic as // get_attribute_equality_selectivity implements let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) - - self.get_attribute_equality_selectivity(table_id, attr_base_index, value, true)?; + - self + .get_attribute_equality_selectivity(table_id, attr_base_index, value, true) + .await?; assert!( (0.0..=1.0).contains(&ret_freq), "ret_freq ({}) should be in [0, 1]", @@ -116,7 +119,7 @@ impl CostModelImpl { /// Range predicates are handled entirely differently from equality predicates so this is its /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. - pub(crate) fn get_attribute_range_selectivity( + pub(crate) async fn get_attribute_range_selectivity( &self, table_id: TableId, attr_base_index: usize, @@ -124,17 +127,21 @@ impl CostModelImpl { end: Bound<&Value>, ) -> CostModelResult { // TODO: Consider attribute is a derived attribute - if let Some(attribute_stats) = - self.get_attribute_comb_stats(table_id, &[attr_base_index])? + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? { let left_quantile = match start { Bound::Unbounded => 0.0, - Bound::Included(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, + Bound::Included(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } Bound::Excluded(value) => { Self::get_attribute_leq_value_freq(&attribute_stats, value) } @@ -144,12 +151,15 @@ impl CostModelImpl { Bound::Included(value) => { Self::get_attribute_leq_value_freq(&attribute_stats, value) } - Bound::Excluded(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, + Bound::Excluded(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } }; assert!( left_quantile <= right_quantile, diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 7bde869..0a2092e 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -16,11 +16,12 @@ use crate::{ // compute the selectivity. stats::{DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, UNIMPLEMENTED_SEL}, CostModelResult, + SemanticError, }; impl CostModelImpl { /// Comparison operators are the base case for recursion in get_filter_selectivity() - pub(crate) fn get_comp_op_selectivity( + pub(crate) async fn get_comp_op_selectivity( &self, comp_bin_op_typ: BinOpType, left: ArcPredicateNode, @@ -30,8 +31,11 @@ impl CostModelImpl { // I intentionally performed moves on left and right. This way, we don't accidentally use // them after this block - let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = - self.get_semantic_nodes(left, right)?; + let semantic_res = self.get_semantic_nodes(left, right).await; + if semantic_res.is_err() { + return Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)); + } + let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = semantic_res.unwrap(); // Handle the different cases of semantic nodes. if attr_ref_exprs.is_empty() { @@ -51,13 +55,17 @@ impl CostModelImpl { match comp_bin_op_typ { BinOpType::Eq => { self.get_attribute_equality_selectivity(table_id, attr_ref_idx, value, true) + .await + } + BinOpType::Neq => { + self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ) + .await } - BinOpType::Neq => self.get_attribute_equality_selectivity( - table_id, - attr_ref_idx, - value, - false, - ), BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { let start = match (comp_bin_op_typ, is_left_attr_ref) { (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, @@ -74,6 +82,7 @@ impl CostModelImpl { _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), }; self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + .await } _ => unreachable!( "all comparison BinOpTypes were enumerated. this should be unreachable" @@ -109,7 +118,7 @@ impl CostModelImpl { /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. /// The last return value is true when the input node (left) is a AttributeRefPred. #[allow(clippy::type_complexity)] - fn get_semantic_nodes( + async fn get_semantic_nodes( &self, left: ArcPredicateNode, right: ArcPredicateNode, @@ -175,11 +184,16 @@ impl CostModelImpl { // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. // TODO(migration): double check - let invert_cast_data_type = &(self + // TODO: Consider attribute info is None. + let attribute_info = self .storage_manager - .get_attribute_info(table_id, attr_ref_idx as i32)? - .typ - .into_data_type()); + .get_attribute_info(table_id, attr_ref_idx as i32) + .await? + .ok_or({ + SemanticError::AttributeNotFound(table_id, attr_ref_idx as i32) + })?; + + let invert_cast_data_type = &attribute_info.typ.into_data_type(); match non_cast_node.typ { PredicateType::AttributeRef => { diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index c0ce8c9..5369fe5 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -13,73 +13,78 @@ use crate::{ impl CostModelImpl { // TODO: is it a good design to pass table_id here? I think it needs to be refactored. // Consider to remove table_id. - pub fn get_filter_row_cnt( + pub async fn get_filter_row_cnt( &self, child_row_cnt: EstimatedStatistic, cond: ArcPredicateNode, ) -> CostModelResult { - let selectivity = { self.get_filter_selectivity(cond)? }; + let selectivity = { self.get_filter_selectivity(cond).await? }; Ok( EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) .max(EstimatedStatistic(1)), ) } - pub fn get_filter_selectivity(&self, expr_tree: ArcPredicateNode) -> CostModelResult { - match &expr_tree.typ { - PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), - PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), - PredicateType::UnOp(un_op_typ) => { - assert!(expr_tree.children.len() == 1); - let child = expr_tree.child(0); - match un_op_typ { - // not doesn't care about nulls so there's no complex logic. it just reverses - // the selectivity for instance, != _will not_ include nulls - // but "NOT ==" _will_ include nulls - UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child)?), - UnOpType::Neg => panic!( - "the selectivity of operations that return numerical values is undefined" - ), + pub async fn get_filter_selectivity( + &self, + expr_tree: ArcPredicateNode, + ) -> CostModelResult { + Box::pin(async move { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child).await?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } } - } - PredicateType::BinOp(bin_op_typ) => { - assert!(expr_tree.children.len() == 2); - let left_child = expr_tree.child(0); - let right_child = expr_tree.child(1); + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); - if bin_op_typ.is_comparison() { - self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child) - } else if bin_op_typ.is_numerical() { - panic!( - "the selectivity of operations that return numerical values is undefined" - ) - } else { - unreachable!("all BinOpTypes should be true for at least one is_*() function") + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child).await + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) + } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } } + PredicateType::LogOp(log_op_typ) => { + self.get_log_op_selectivity(*log_op_typ, &expr_tree.children).await + } + PredicateType::Func(_) => unimplemented!("check bool type or else panic"), + PredicateType::SortOrder(_) => { + panic!("the selectivity of sort order expressions is undefined") + } + PredicateType::Between => Ok(UNIMPLEMENTED_SEL), + PredicateType::Cast => unimplemented!("check bool type or else panic"), + PredicateType::Like => { + let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); + self.get_like_selectivity(&like_expr).await + } + PredicateType::DataType(_) => { + panic!("the selectivity of a data type is not defined") + } + PredicateType::InList => { + let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); + self.get_in_list_selectivity(&in_list_expr).await + } + _ => unreachable!( + "all expression DfPredType were enumerated. this should be unreachable" + ), } - PredicateType::LogOp(log_op_typ) => { - self.get_log_op_selectivity(*log_op_typ, &expr_tree.children) - } - PredicateType::Func(_) => unimplemented!("check bool type or else panic"), - PredicateType::SortOrder(_) => { - panic!("the selectivity of sort order expressions is undefined") - } - PredicateType::Between => Ok(UNIMPLEMENTED_SEL), - PredicateType::Cast => unimplemented!("check bool type or else panic"), - PredicateType::Like => { - let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); - self.get_like_selectivity(&like_expr) - } - PredicateType::DataType(_) => { - panic!("the selectivity of a data type is not defined") - } - PredicateType::InList => { - let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); - self.get_in_list_selectivity(&in_list_expr) - } - _ => unreachable!( - "all expression DfPredType were enumerated. this should be unreachable" - ), - } + }).await } } diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index cc2a570..f1eed06 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -15,7 +15,7 @@ use crate::{ impl CostModelImpl { /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and /// val1, val2, val3 are constants. - pub(crate) fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { + pub(crate) async fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { let child = expr.child(); // Check child is a attribute ref. @@ -46,18 +46,19 @@ impl CostModelImpl { let negated = expr.negated(); // TODO: Consider attribute is a derived attribute - let in_sel = list_exprs - .iter() - .try_fold(0.0, |acc, expr| { - let selectivity = self.get_attribute_equality_selectivity( + let mut in_sel = 0.0; + for expr in &list_exprs { + let selectivity = self + .get_attribute_equality_selectivity( table_id, attr_ref_idx, &expr.value(), /* is_equality */ true, - )?; - Ok(acc + selectivity) - })? - .min(1.0); + ) + .await?; + in_sel += selectivity; + } + in_sel = in_sel.min(1.0); if negated { Ok(1.0 - in_sel) } else { diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index f8a1ab4..fe9214b 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -28,7 +28,7 @@ impl CostModelImpl { /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. - pub(crate) fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { + pub(crate) async fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { let child = like_expr.child(); // Check child is a attribute ref. @@ -67,7 +67,10 @@ impl CostModelImpl { // Compute the selectivity in MCVs. // TODO: Handle the case where `attribute_stats` is None. - if let Some(attribute_stats) = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])? { + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_ref_idx]) + .await? + { let (mcv_freq, null_frac) = { let pred = Box::new(move |val: &AttributeCombValue| { let string = diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs index 46e6a21..63e7cd1 100644 --- a/optd-cost-model/src/cost/filter/log_op.rs +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -7,22 +7,27 @@ use crate::{ }; impl CostModelImpl { - pub(crate) fn get_log_op_selectivity( + pub(crate) async fn get_log_op_selectivity( &self, log_op_typ: LogOpType, children: &[ArcPredicateNode], ) -> CostModelResult { match log_op_typ { - LogOpType::And => children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone())?; - Ok(acc * selectivity) - }), + LogOpType::And => { + let mut and_sel = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(child.clone()).await?; + and_sel *= selectivity; + } + Ok(and_sel) + } LogOpType::Or => { - let product = children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone())?; - Ok(acc * (1.0 - selectivity)) - })?; - Ok(1.0 - product) + let mut or_sel_neg = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(child.clone()).await?; + or_sel_neg *= (1.0 - selectivity); + } + Ok(1.0 - or_sel_neg) } } } diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index ebb45c2..1942558 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -97,12 +97,13 @@ impl CostModelImpl { /// TODO: documentation /// TODO: if we have memory cache, /// we should add the reference. (&AttributeCombValueStats) - pub(crate) fn get_attribute_comb_stats( + pub(crate) async fn get_attribute_comb_stats( &self, table_id: TableId, attr_comb: &[usize], ) -> CostModelResult> { self.storage_manager .get_attributes_comb_statistics(table_id, attr_comb) + .await } } diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index a2afcb8..d4a24d2 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -42,13 +42,14 @@ pub enum SemanticError { UnknownStatisticType, VersionedStatisticNotFound, AttributeNotFound(TableId, i32), // (table_id, attribute_base_index) + // FIXME: not sure if this should be put here + InvalidPredicate(String), } #[derive(Debug)] pub enum CostModelError { ORMError(BackendError), SemanticError(SemanticError), - InvalidPredicate(String), } impl From for CostModelError { @@ -57,6 +58,12 @@ impl From for CostModelError { } } +impl From for CostModelError { + fn from(err: SemanticError) -> Self { + CostModelError::SemanticError(err) + } +} + pub trait CostModel: 'static + Send + Sync { /// TODO: documentation fn compute_operation_cost( diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage.rs index 1ee5d0e..67b813d 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage.rs @@ -2,16 +2,24 @@ use std::sync::Arc; use optd_persistent::{ - cost_model::interface::{Attr, StatType}, + cost_model::interface::{AttrType, StatType}, CostModelStorageLayer, }; +use serde::{Deserialize, Serialize}; use crate::{ - common::types::TableId, + common::{predicates::constant_pred::ConstantType, types::TableId}, stats::{counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, CostModelResult, }; +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Attribute { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} + /// TODO: documentation pub struct CostModelStorageManager { pub backend_manager: Arc, @@ -31,11 +39,24 @@ impl CostModelStorageManager { &self, table_id: TableId, attr_base_index: i32, - ) -> CostModelResult> { - Ok(self + ) -> CostModelResult> { + let attr = self .backend_manager .get_attribute(table_id.into(), attr_base_index) - .await?) + .await?; + match attr { + Some(attr) => Ok(Some(Attribute { + name: attr.name, + typ: match attr.attr_type { + AttrType::Integer => ConstantType::Int32, + AttrType::Float => ConstantType::Float64, + AttrType::Varchar => ConstantType::Utf8String, + AttrType::Boolean => ConstantType::Bool, + }, + nullable: attr.nullable, + })), + None => Ok(None), + } } /// Gets the latest statistics for a given table. @@ -53,13 +74,13 @@ impl CostModelStorageManager { pub async fn get_attributes_comb_statistics( &self, table_id: TableId, - attr_base_indices: &[i32], + attr_base_indices: &[usize], ) -> CostModelResult> { let dist: Option = self .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::Distribution, None, ) @@ -70,7 +91,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::MostCommonValues, None, ) @@ -82,7 +103,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::Cardinality, None, ) @@ -94,7 +115,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::TableRowCount, None, ) @@ -105,7 +126,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::NonNullCount, None, ) diff --git a/optd-persistent/Cargo.toml b/optd-persistent/Cargo.toml index e9b9905..5d03a14 100644 --- a/optd-persistent/Cargo.toml +++ b/optd-persistent/Cargo.toml @@ -21,3 +21,4 @@ trait-variant = "0.1.2" async-trait = "0.1.43" async-stream = "0.3.1" strum = "0.26.1" +num_enum = "0.7.3" diff --git a/optd-persistent/src/cost_model/interface.rs b/optd-persistent/src/cost_model/interface.rs index a03087f..598598d 100644 --- a/optd-persistent/src/cost_model/interface.rs +++ b/optd-persistent/src/cost_model/interface.rs @@ -4,6 +4,7 @@ use crate::entities::cascades_group; use crate::entities::logical_expression; use crate::entities::physical_expression; use crate::StorageResult; +use num_enum::{IntoPrimitive, TryFromPrimitive}; use sea_orm::prelude::Json; use sea_orm::*; use sea_orm_migration::prelude::*; @@ -24,8 +25,10 @@ pub enum CatalogSource { } /// TODO: documentation +#[repr(i32)] +#[derive(Copy, Clone, Debug, PartialEq, IntoPrimitive, TryFromPrimitive)] pub enum AttrType { - Integer, + Integer = 1, Float, Varchar, Boolean, @@ -96,7 +99,7 @@ pub struct Attr { pub table_id: i32, pub name: String, pub compression_method: String, - pub attr_type: i32, + pub attr_type: AttrType, pub base_index: i32, pub nullable: bool, } diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index d172c14..65d6035 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -14,7 +14,8 @@ use serde_json::json; use super::catalog::mock_catalog::{self, MockCatalog}; use super::interface::{ - Attr, AttrId, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, TableId, + Attr, AttrId, AttrType, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, + TableId, }; impl BackendManager { @@ -543,19 +544,28 @@ impl CostModelStorageLayer for BackendManager { table_id: TableId, attribute_base_index: i32, ) -> StorageResult> { - Ok(Attribute::find() + let attr_res = Attribute::find() .filter(attribute::Column::TableId.eq(table_id)) .filter(attribute::Column::BaseAttributeNumber.eq(attribute_base_index)) .one(&self.db) - .await? - .map(|attr| Attr { - table_id, - name: attr.name, - compression_method: attr.compression_method, - attr_type: attr.variant_tag, - base_index: attribute_base_index, - nullable: !attr.is_not_null, - })) + .await?; + match attr_res { + Some(attr) => match AttrType::try_from(attr.variant_tag) { + Ok(attr_type) => Ok(Some(Attr { + table_id: attr.table_id, + name: attr.name, + compression_method: attr.compression_method, + attr_type, + base_index: attr.base_attribute_number, + nullable: attr.is_not_null, + })), + Err(_) => Err(BackendError::BackendError(format!( + "Failed to convert variant tag {} to AttrType", + attr.variant_tag + ))), + }, + None => Ok(None), + } } }