From a533ce69c16feffabacbb3cdb3a72a54902205bf Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 10:11:25 -0500 Subject: [PATCH] Partial migration of filter --- optd-cost-model/src/common/nodes.rs | 27 +- .../src/common/predicates/attr_ref_pred.rs | 45 +++ .../src/common/predicates/cast_pred.rs | 44 +++ .../src/common/predicates/constant_pred.rs | 189 +++++++++++ .../src/common/predicates/data_type_pred.rs | 40 +++ optd-cost-model/src/common/predicates/mod.rs | 3 + optd-cost-model/src/cost/filter.rs | 306 ++++++++++++++++++ optd-cost-model/src/cost_model.rs | 4 +- optd-cost-model/src/lib.rs | 8 +- optd-cost-model/src/storage.rs | 25 ++ 10 files changed, 686 insertions(+), 5 deletions(-) create mode 100644 optd-cost-model/src/common/predicates/attr_ref_pred.rs create mode 100644 optd-cost-model/src/common/predicates/cast_pred.rs create mode 100644 optd-cost-model/src/common/predicates/data_type_pred.rs diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 38e2500..0bbcca1 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -77,7 +77,7 @@ pub struct PredicateNode { /// A generic predicate node type pub typ: PredicateType, /// Child predicate nodes, always materialized - pub children: Vec, + pub children: Vec, /// Data associated with the predicate, if any pub data: Option, } @@ -94,3 +94,28 @@ impl std::fmt::Display for PredicateNode { write!(f, ")") } } + +impl PredicateNode { + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.children[idx].clone() + } + + pub fn unwrap_data(&self) -> Value { + self.data.clone().unwrap() + } +} +pub trait ReprPredicateNode: 'static + Clone { + fn into_pred_node(self) -> ArcPredicateNode; + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option; +} + +impl ReprPredicateNode for ArcPredicateNode { + fn into_pred_node(self) -> ArcPredicateNode { + self + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + Some(pred_node) + } +} diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs new file mode 100644 index 0000000..34e1901 --- /dev/null +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -0,0 +1,45 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +#[derive(Clone, Debug)] +pub struct AttributeRefPred(pub ArcPredicateNode); + +impl AttributeRefPred { + /// Creates a new `ColumnRef` expression. + pub fn new(column_idx: usize) -> AttributeRefPred { + // this conversion is always safe since usize is at most u64 + let u64_column_idx = column_idx as u64; + AttributeRefPred( + PredicateNode { + typ: PredicateType::AttributeRef, + children: vec![], + data: Some(Value::UInt64(u64_column_idx)), + } + .into(), + ) + } + + fn get_data_usize(&self) -> usize { + self.0.data.as_ref().unwrap().as_u64() as usize + } + + /// Gets the column index. + pub fn index(&self) -> usize { + self.get_data_usize() + } +} + +impl ReprPredicateNode for AttributeRefPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::AttributeRef { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/cast_pred.rs b/optd-cost-model/src/common/predicates/cast_pred.rs new file mode 100644 index 0000000..eaafca9 --- /dev/null +++ b/optd-cost-model/src/common/predicates/cast_pred.rs @@ -0,0 +1,44 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::data_type_pred::DataTypePred; + +#[derive(Clone, Debug)] +pub struct CastPred(pub ArcPredicateNode); + +impl CastPred { + pub fn new(child: ArcPredicateNode, cast_to: DataType) -> Self { + CastPred( + PredicateNode { + typ: PredicateType::Cast, + children: vec![child, DataTypePred::new(cast_to).into_pred_node()], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn cast_to(&self) -> DataType { + DataTypePred::from_pred_node(self.0.child(1)) + .unwrap() + .data_type() + } +} + +impl ReprPredicateNode for CastPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Cast) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/constant_pred.rs b/optd-cost-model/src/common/predicates/constant_pred.rs index 7923ae4..2fa06ae 100644 --- a/optd-cost-model/src/common/predicates/constant_pred.rs +++ b/optd-cost-model/src/common/predicates/constant_pred.rs @@ -1,5 +1,13 @@ +use std::sync::Arc; + +use arrow_schema::{DataType, IntervalUnit}; use serde::{Deserialize, Serialize}; +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::{SerializableOrderedF64, Value}, +}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub enum ConstantType { @@ -19,3 +27,184 @@ pub enum ConstantType { Decimal, Binary, } + +impl ConstantType { + pub fn get_data_type_from_value(value: &Value) -> Self { + match value { + Value::Bool(_) => ConstantType::Bool, + Value::String(_) => ConstantType::Utf8String, + Value::UInt8(_) => ConstantType::UInt8, + Value::UInt16(_) => ConstantType::UInt16, + Value::UInt32(_) => ConstantType::UInt32, + Value::UInt64(_) => ConstantType::UInt64, + Value::Int8(_) => ConstantType::Int8, + Value::Int16(_) => ConstantType::Int16, + Value::Int32(_) => ConstantType::Int32, + Value::Int64(_) => ConstantType::Int64, + Value::Float(_) => ConstantType::Float64, + Value::Date32(_) => ConstantType::Date, + _ => unimplemented!("get_data_type_from_value() not implemented for value {value}"), + } + } + + // TODO: current DataType and ConstantType are not 1 to 1 mapping + // optd schema stores constantType from data type in catalog.get + // for decimal128, the precision is lost + pub fn from_data_type(data_type: DataType) -> Self { + match data_type { + DataType::Binary => ConstantType::Binary, + DataType::Boolean => ConstantType::Bool, + DataType::UInt8 => ConstantType::UInt8, + DataType::UInt16 => ConstantType::UInt16, + DataType::UInt32 => ConstantType::UInt32, + DataType::UInt64 => ConstantType::UInt64, + DataType::Int8 => ConstantType::Int8, + DataType::Int16 => ConstantType::Int16, + DataType::Int32 => ConstantType::Int32, + DataType::Int64 => ConstantType::Int64, + DataType::Float64 => ConstantType::Float64, + DataType::Date32 => ConstantType::Date, + DataType::Interval(IntervalUnit::MonthDayNano) => ConstantType::IntervalMonthDateNano, + DataType::Utf8 => ConstantType::Utf8String, + DataType::Decimal128(_, _) => ConstantType::Decimal, + _ => unimplemented!("no conversion to ConstantType for DataType {data_type}"), + } + } + + pub fn into_data_type(&self) -> DataType { + match self { + ConstantType::Binary => DataType::Binary, + ConstantType::Bool => DataType::Boolean, + ConstantType::UInt8 => DataType::UInt8, + ConstantType::UInt16 => DataType::UInt16, + ConstantType::UInt32 => DataType::UInt32, + ConstantType::UInt64 => DataType::UInt64, + ConstantType::Int8 => DataType::Int8, + ConstantType::Int16 => DataType::Int16, + ConstantType::Int32 => DataType::Int32, + ConstantType::Int64 => DataType::Int64, + ConstantType::Float64 => DataType::Float64, + ConstantType::Date => DataType::Date32, + ConstantType::IntervalMonthDateNano => DataType::Interval(IntervalUnit::MonthDayNano), + ConstantType::Decimal => DataType::Float64, + ConstantType::Utf8String => DataType::Utf8, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConstantPred(pub ArcPredicateNode); + +impl ConstantPred { + pub fn new(value: Value) -> Self { + let typ = ConstantType::get_data_type_from_value(&value); + Self::new_with_type(value, typ) + } + + pub fn new_with_type(value: Value, typ: ConstantType) -> Self { + ConstantPred( + PredicateNode { + typ: PredicateType::Constant(typ), + children: vec![], + data: Some(value), + } + .into(), + ) + } + + pub fn bool(value: bool) -> Self { + Self::new_with_type(Value::Bool(value), ConstantType::Bool) + } + + pub fn string(value: impl AsRef) -> Self { + Self::new_with_type( + Value::String(value.as_ref().into()), + ConstantType::Utf8String, + ) + } + + pub fn uint8(value: u8) -> Self { + Self::new_with_type(Value::UInt8(value), ConstantType::UInt8) + } + + pub fn uint16(value: u16) -> Self { + Self::new_with_type(Value::UInt16(value), ConstantType::UInt16) + } + + pub fn uint32(value: u32) -> Self { + Self::new_with_type(Value::UInt32(value), ConstantType::UInt32) + } + + pub fn uint64(value: u64) -> Self { + Self::new_with_type(Value::UInt64(value), ConstantType::UInt64) + } + + pub fn int8(value: i8) -> Self { + Self::new_with_type(Value::Int8(value), ConstantType::Int8) + } + + pub fn int16(value: i16) -> Self { + Self::new_with_type(Value::Int16(value), ConstantType::Int16) + } + + pub fn int32(value: i32) -> Self { + Self::new_with_type(Value::Int32(value), ConstantType::Int32) + } + + pub fn int64(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Int64) + } + + pub fn interval_month_day_nano(value: i128) -> Self { + Self::new_with_type(Value::Int128(value), ConstantType::IntervalMonthDateNano) + } + + pub fn float64(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Float64, + ) + } + + pub fn date(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Date) + } + + pub fn decimal(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Decimal, + ) + } + + pub fn serialized(value: Arc<[u8]>) -> Self { + Self::new_with_type(Value::Serialized(value), ConstantType::Binary) + } + + /// Gets the constant value. + pub fn value(&self) -> Value { + self.0.data.clone().unwrap() + } + + pub fn constant_type(&self) -> ConstantType { + if let PredicateType::Constant(typ) = self.0.typ { + typ + } else { + panic!("not a constant") + } + } +} + +impl ReprPredicateNode for ConstantPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(rel_node: ArcPredicateNode) -> Option { + if let PredicateType::Constant(_) = rel_node.typ { + Some(Self(rel_node)) + } else { + None + } + } +} diff --git a/optd-cost-model/src/common/predicates/data_type_pred.rs b/optd-cost-model/src/common/predicates/data_type_pred.rs new file mode 100644 index 0000000..fe29336 --- /dev/null +++ b/optd-cost-model/src/common/predicates/data_type_pred.rs @@ -0,0 +1,40 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct DataTypePred(pub ArcPredicateNode); + +impl DataTypePred { + pub fn new(typ: DataType) -> Self { + DataTypePred( + PredicateNode { + typ: PredicateType::DataType(typ), + children: vec![], + data: None, + } + .into(), + ) + } + + pub fn data_type(&self) -> DataType { + if let PredicateType::DataType(ref data_type) = self.0.typ { + data_type.clone() + } else { + panic!("not a data type") + } + } +} + +impl ReprPredicateNode for DataTypePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::DataType(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index 87e6e94..d733198 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ b/optd-cost-model/src/common/predicates/mod.rs @@ -1,5 +1,8 @@ +pub mod attr_ref_pred; pub mod bin_op_pred; +pub mod cast_pred; pub mod constant_pred; +pub mod data_type_pred; pub mod func_pred; pub mod log_op_pred; pub mod sort_order_pred; diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 8b13789..056b0c1 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -1 +1,307 @@ +#![allow(unused_variables)] +use optd_persistent::CostModelStorageLayer; +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{ + attr_ref_pred::AttributeRefPred, + bin_op_pred::BinOpType, + cast_pred::CastPred, + constant_pred::{ConstantPred, ConstantType}, + un_op_pred::UnOpType, + }, + values::Value, + }, + cost_model::CostModelImpl, + CostModelResult, EstimatedStatistic, +}; + +// A placeholder for unimplemented!() for codepaths which are accessed by plannertest +const UNIMPLEMENTED_SEL: f64 = 0.01; +// Default statistics. All are from selfuncs.h in Postgres unless specified otherwise +// Default selectivity estimate for equalities such as "A = b" +const DEFAULT_EQ_SEL: f64 = 0.005; +// Default selectivity estimate for inequalities such as "A < b" +const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333; + +impl CostModelImpl { + pub fn get_filter_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + table_id: i32, + cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { self.get_filter_selectivity(cond, table_id)? }; + Ok( + EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) + .max(EstimatedStatistic(1)), + ) + } + + pub fn get_filter_selectivity( + &self, + expr_tree: ArcPredicateNode, + table_id: i32, + ) -> CostModelResult { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child, table_id)?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } + } + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child, table_id) + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) + } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } + } + _ => unimplemented!("check bool type or else panic"), + } + } + + fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { + if let PredicateType::Constant(const_typ) = const_node.typ { + if matches!(const_typ, ConstantType::Bool) { + let value = const_node + .as_ref() + .data + .as_ref() + .expect("constants should have data"); + if let Value::Bool(bool_value) = value { + if *bool_value { + 1.0 + } else { + 0.0 + } + } else { + unreachable!( + "if the typ is ConstantType::Bool, the value should be a Value::Bool" + ) + } + } else { + panic!("selectivity is not defined on constants which are not bools") + } + } else { + panic!("get_constant_selectivity must be called on a constant") + } + } + + /// Comparison operators are the base case for recursion in get_filter_selectivity() + fn get_comp_op_selectivity( + &self, + comp_bin_op_typ: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + table_id: i32, + ) -> CostModelResult { + assert!(comp_bin_op_typ.is_comparison()); + + // I intentionally performed moves on left and right. This way, we don't accidentally use + // them after this block + let (col_ref_exprs, values, non_col_ref_exprs, is_left_col_ref) = + self.get_semantic_nodes(left, right, table_id)?; + + // Handle the different cases of semantic nodes. + if col_ref_exprs.is_empty() { + Ok(UNIMPLEMENTED_SEL) + } else if col_ref_exprs.len() == 1 { + let col_ref_expr = col_ref_exprs + .first() + .expect("we just checked that col_ref_exprs.len() == 1"); + let col_ref_idx = col_ref_expr.index(); + + todo!() + } else if col_ref_exprs.len() == 2 { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } else { + unreachable!("we could have at most pushed left and right into col_ref_exprs") + } + } + + /// Convert the left and right child nodes of some operation to what they semantically are. + /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. + /// The last return value is true when the input node (left) is a ColumnRefPred. + #[allow(clippy::type_complexity)] + fn get_semantic_nodes( + &self, + left: ArcPredicateNode, + right: ArcPredicateNode, + table_id: i32, + ) -> CostModelResult<( + Vec, + Vec, + Vec, + bool, + )> { + let mut col_ref_exprs = vec![]; + let mut values = vec![]; + let mut non_col_ref_exprs = vec![]; + let is_left_col_ref; + + // Recursively unwrap casts as much as we can. + let mut uncasted_left = left; + let mut uncasted_right = right; + loop { + // println!("loop {}, uncasted_left={:?}, uncasted_right={:?}", Local::now(), + // uncasted_left, uncasted_right); + if uncasted_left.as_ref().typ == PredicateType::Cast + && uncasted_right.as_ref().typ == PredicateType::Cast + { + let left_cast_expr = CastPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Cast"); + let right_cast_expr = CastPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Cast"); + assert!(left_cast_expr.cast_to() == right_cast_expr.cast_to()); + uncasted_left = left_cast_expr.child().into_pred_node(); + uncasted_right = right_cast_expr.child().into_pred_node(); + } else if uncasted_left.as_ref().typ == PredicateType::Cast + || uncasted_right.as_ref().typ == PredicateType::Cast + { + let is_left_cast = uncasted_left.as_ref().typ == PredicateType::Cast; + let (mut cast_node, mut non_cast_node) = if is_left_cast { + (uncasted_left, uncasted_right) + } else { + (uncasted_right, uncasted_left) + }; + + let cast_expr = CastPred::from_pred_node(cast_node) + .expect("we already checked that the type is Cast"); + let cast_expr_child = cast_expr.child().into_pred_node(); + let cast_expr_cast_to = cast_expr.cast_to(); + + let should_break = match cast_expr_child.typ { + PredicateType::Constant(_) => { + cast_node = ConstantPred::new( + ConstantPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is Constant") + .value() + .convert_to_type(cast_expr_cast_to), + ) + .into_pred_node(); + false + } + PredicateType::AttributeRef => { + let col_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is ColumnRef"); + let col_ref_idx = col_ref_expr.index(); + cast_node = col_ref_expr.into_pred_node(); + // The "invert" cast is to invert the cast so that we're casting the + // non_cast_node to the column's original type. + // TODO(migration): double check + let invert_cast_data_type = &(self + .storage_manager + .get_attribute_info(table_id, col_ref_idx as i32)? + .typ + .into_data_type()); + + match non_cast_node.typ { + PredicateType::AttributeRef => { + // In general, there's no way to remove the Cast here. We can't move + // the Cast to the other ColumnRef + // because that would lead to an infinite loop. Thus, we just leave + // the cast where it is and break. + true + } + _ => { + non_cast_node = + CastPred::new(non_cast_node, invert_cast_data_type.clone()) + .into_pred_node(); + false + } + } + } + _ => todo!(), + }; + + (uncasted_left, uncasted_right) = if is_left_cast { + (cast_node, non_cast_node) + } else { + (non_cast_node, cast_node) + }; + + if should_break { + break; + } + } else { + break; + } + } + + // Sort nodes into col_ref_exprs, values, and non_col_ref_exprs + match uncasted_left.as_ref().typ { + PredicateType::AttributeRef => { + is_left_col_ref = true; + col_ref_exprs.push( + AttributeRefPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is ColumnRef"), + ); + } + PredicateType::Constant(_) => { + is_left_col_ref = false; + values.push( + ConstantPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Constant") + .value(), + ) + } + _ => { + is_left_col_ref = false; + non_col_ref_exprs.push(uncasted_left); + } + } + match uncasted_right.as_ref().typ { + PredicateType::AttributeRef => { + col_ref_exprs.push( + AttributeRefPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is ColumnRef"), + ); + } + PredicateType::Constant(_) => values.push( + ConstantPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Constant") + .value(), + ), + _ => { + non_col_ref_exprs.push(uncasted_right); + } + } + + assert!(col_ref_exprs.len() + values.len() + non_col_ref_exprs.len() == 2); + Ok((col_ref_exprs, values, non_col_ref_exprs, is_left_col_ref)) + } + + /// The default selectivity of a comparison expression + /// Used when one side of the comparison is a column while the other side is something too + /// complex/impossible to evaluate (subquery, UDF, another column, we have no stats, etc.) + fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { + assert!(comp_bin_op_typ.is_comparison()); + match comp_bin_op_typ { + BinOpType::Eq => DEFAULT_EQ_SEL, + BinOpType::Neq => 1.0 - DEFAULT_EQ_SEL, + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => DEFAULT_INEQ_SEL, + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } +} diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index c0b0677..0b1760e 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -18,8 +18,8 @@ use crate::{ /// TODO: documentation pub struct CostModelImpl { - storage_manager: CostModelStorageManager, - default_catalog_source: CatalogSource, + pub storage_manager: CostModelStorageManager, + pub default_catalog_source: CatalogSource, } impl CostModelImpl { diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index 6aeb476..e950ae6 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -2,7 +2,10 @@ use common::{ nodes::{ArcPredicateNode, PhysicalNodeType}, types::{AttrId, EpochId, ExprId, GroupId, TableId}, }; -use optd_persistent::cost_model::interface::{Stat, StatType}; +use optd_persistent::{ + cost_model::interface::{Stat, StatType}, + BackendError, +}; pub mod common; pub mod cost; @@ -27,6 +30,7 @@ pub struct Cost(pub Vec); /// Estimated statistic calculated by the cost model. /// It is the estimated output row count of the targeted expression. +#[derive(Eq, Ord, PartialEq, PartialOrd)] pub struct EstimatedStatistic(pub u64); pub type CostModelResult = Result; @@ -34,7 +38,7 @@ pub type CostModelResult = Result; #[derive(Debug)] pub enum CostModelError { // TODO: Add more error types - ORMError, + ORMError(BackendError), } pub trait CostModel: 'static + Send + Sync { diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage.rs index 2a53d7c..e8bea67 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage.rs @@ -1,6 +1,27 @@ +#![allow(unused_variables)] use std::sync::Arc; use optd_persistent::CostModelStorageLayer; +use serde::{Deserialize, Serialize}; + +use crate::{common::predicates::constant_pred::ConstantType, CostModelResult}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Field { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} + +impl std::fmt::Display for Field { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.nullable { + write!(f, "{}:{:?}", self.name, self.typ) + } else { + write!(f, "{}:{:?}(non-null)", self.name, self.typ) + } + } +} /// TODO: documentation pub struct CostModelStorageManager { @@ -13,4 +34,8 @@ impl CostModelStorageManager { pub fn new(backend_manager: Arc) -> Self { Self { backend_manager } } + + pub fn get_attribute_info(&self, table_id: i32, attribute_id: i32) -> CostModelResult { + todo!() + } }