Skip to content

Commit

Permalink
Partial migration of filter
Browse files Browse the repository at this point in the history
  • Loading branch information
lanlou1554 committed Nov 14, 2024
1 parent 21d01ae commit a533ce6
Show file tree
Hide file tree
Showing 10 changed files with 686 additions and 5 deletions.
27 changes: 26 additions & 1 deletion optd-cost-model/src/common/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub struct PredicateNode {
/// A generic predicate node type
pub typ: PredicateType,
/// Child predicate nodes, always materialized
pub children: Vec<PredicateNode>,
pub children: Vec<ArcPredicateNode>,
/// Data associated with the predicate, if any
pub data: Option<Value>,
}
Expand All @@ -94,3 +94,28 @@ impl std::fmt::Display for PredicateNode {
write!(f, ")")
}
}

impl PredicateNode {
    /// Returns a clone of the child handle stored at position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds for `children`.
    pub fn child(&self, idx: usize) -> ArcPredicateNode {
        let handle = &self.children[idx];
        handle.clone()
    }

    /// Returns a clone of the data attached to this node.
    ///
    /// # Panics
    ///
    /// Panics if the node carries no data (`data` is `None`).
    pub fn unwrap_data(&self) -> Value {
        let payload = self.data.clone();
        payload.unwrap()
    }
}
/// Conversion contract between a typed predicate representation and the
/// generic [`ArcPredicateNode`] form.
pub trait ReprPredicateNode: Clone + 'static {
    /// Consumes `self` and yields the underlying generic predicate node.
    fn into_pred_node(self) -> ArcPredicateNode;

    /// Attempts to interpret `pred_node` as `Self`.
    ///
    /// Returns `None` when the node cannot be represented as `Self`.
    fn from_pred_node(pred_node: ArcPredicateNode) -> Option<Self>;
}

/// Identity implementation: a generic predicate node is its own representation.
impl ReprPredicateNode for ArcPredicateNode {
    /// Returns the node unchanged.
    fn into_pred_node(self) -> ArcPredicateNode {
        let node = self;
        node
    }

    /// Always succeeds, handing the node back wrapped in `Some`.
    fn from_pred_node(pred_node: ArcPredicateNode) -> Option<Self> {
        Option::Some(pred_node)
    }
}
45 changes: 45 additions & 0 deletions optd-cost-model/src/common/predicates/attr_ref_pred.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::common::{
nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode},
values::Value,
};

/// Typed wrapper around a predicate node that refers to an attribute by index.
///
/// `AttributeRefPred::new` and `from_pred_node` only produce wrappers around
/// `PredicateType::AttributeRef` nodes; the inner field is public, so direct
/// construction does not enforce that.
#[derive(Debug, Clone)]
pub struct AttributeRefPred(pub ArcPredicateNode);

impl AttributeRefPred {
    /// Creates a new attribute-reference predicate for the attribute at `column_idx`.
    ///
    /// The index is stored in the node's data as a `Value::UInt64`; the node has no children.
    pub fn new(column_idx: usize) -> AttributeRefPred {
        // The widening conversion is always safe since usize is at most 64 bits wide.
        let data = Some(Value::UInt64(column_idx as u64));
        let node = PredicateNode {
            typ: PredicateType::AttributeRef,
            children: Vec::new(),
            data,
        };
        AttributeRefPred(node.into())
    }

    /// Reads the node's data back as a `usize`.
    ///
    /// Panics if the wrapped node carries no data.
    fn get_data_usize(&self) -> usize {
        let raw = self.0.data.as_ref().unwrap().as_u64();
        raw as usize
    }

    /// Gets the attribute (column) index stored in this predicate.
    pub fn index(&self) -> usize {
        self.get_data_usize()
    }
}

impl ReprPredicateNode for AttributeRefPred {
fn into_pred_node(self) -> ArcPredicateNode {
self.0
}

fn from_pred_node(pred_node: ArcPredicateNode) -> Option<Self> {
if pred_node.typ != PredicateType::AttributeRef {
return None;
}
Some(Self(pred_node))
}
}
44 changes: 44 additions & 0 deletions optd-cost-model/src/common/predicates/cast_pred.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use arrow_schema::DataType;

use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode};

use super::data_type_pred::DataTypePred;

/// Typed wrapper around a predicate node that represents a cast.
///
/// Nodes built by `CastPred::new` have two children: the expression being cast
/// and a data-type predicate describing the target type.
#[derive(Debug, Clone)]
pub struct CastPred(pub ArcPredicateNode);

impl CastPred {
    /// Builds a cast predicate that casts `child` to `cast_to`.
    ///
    /// Child 0 is the input expression; child 1 is a `DataTypePred` holding the target type.
    pub fn new(child: ArcPredicateNode, cast_to: DataType) -> Self {
        let target_type = DataTypePred::new(cast_to).into_pred_node();
        let node = PredicateNode {
            typ: PredicateType::Cast,
            children: vec![child, target_type],
            data: None,
        };
        CastPred(node.into())
    }

    /// Returns the expression being cast (child 0).
    pub fn child(&self) -> ArcPredicateNode {
        let inner = &self.0;
        inner.child(0)
    }

    /// Returns the target data type of the cast (read from child 1).
    ///
    /// # Panics
    ///
    /// Panics if child 1 is missing or is not a data-type predicate.
    pub fn cast_to(&self) -> DataType {
        let type_node = self.0.child(1);
        let type_pred = DataTypePred::from_pred_node(type_node).unwrap();
        type_pred.data_type()
    }
}

impl ReprPredicateNode for CastPred {
fn into_pred_node(self) -> ArcPredicateNode {
self.0
}

fn from_pred_node(pred_node: ArcPredicateNode) -> Option<Self> {
if !matches!(pred_node.typ, PredicateType::Cast) {
return None;
}
Some(Self(pred_node))
}
}
189 changes: 189 additions & 0 deletions optd-cost-model/src/common/predicates/constant_pred.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
use std::sync::Arc;

use arrow_schema::{DataType, IntervalUnit};
use serde::{Deserialize, Serialize};

use crate::common::{
nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode},
values::{SerializableOrderedF64, Value},
};

/// TODO: documentation
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub enum ConstantType {
Expand All @@ -19,3 +27,184 @@ pub enum ConstantType {
Decimal,
Binary,
}

impl ConstantType {
pub fn get_data_type_from_value(value: &Value) -> Self {
match value {
Value::Bool(_) => ConstantType::Bool,
Value::String(_) => ConstantType::Utf8String,
Value::UInt8(_) => ConstantType::UInt8,
Value::UInt16(_) => ConstantType::UInt16,
Value::UInt32(_) => ConstantType::UInt32,
Value::UInt64(_) => ConstantType::UInt64,
Value::Int8(_) => ConstantType::Int8,
Value::Int16(_) => ConstantType::Int16,
Value::Int32(_) => ConstantType::Int32,
Value::Int64(_) => ConstantType::Int64,
Value::Float(_) => ConstantType::Float64,
Value::Date32(_) => ConstantType::Date,
_ => unimplemented!("get_data_type_from_value() not implemented for value {value}"),
}
}

// TODO: current DataType and ConstantType are not 1 to 1 mapping
// optd schema stores constantType from data type in catalog.get
// for decimal128, the precision is lost
pub fn from_data_type(data_type: DataType) -> Self {
match data_type {
DataType::Binary => ConstantType::Binary,
DataType::Boolean => ConstantType::Bool,
DataType::UInt8 => ConstantType::UInt8,
DataType::UInt16 => ConstantType::UInt16,
DataType::UInt32 => ConstantType::UInt32,
DataType::UInt64 => ConstantType::UInt64,
DataType::Int8 => ConstantType::Int8,
DataType::Int16 => ConstantType::Int16,
DataType::Int32 => ConstantType::Int32,
DataType::Int64 => ConstantType::Int64,
DataType::Float64 => ConstantType::Float64,
DataType::Date32 => ConstantType::Date,
DataType::Interval(IntervalUnit::MonthDayNano) => ConstantType::IntervalMonthDateNano,
DataType::Utf8 => ConstantType::Utf8String,
DataType::Decimal128(_, _) => ConstantType::Decimal,
_ => unimplemented!("no conversion to ConstantType for DataType {data_type}"),
}
}

pub fn into_data_type(&self) -> DataType {
match self {
ConstantType::Binary => DataType::Binary,
ConstantType::Bool => DataType::Boolean,
ConstantType::UInt8 => DataType::UInt8,
ConstantType::UInt16 => DataType::UInt16,
ConstantType::UInt32 => DataType::UInt32,
ConstantType::UInt64 => DataType::UInt64,
ConstantType::Int8 => DataType::Int8,
ConstantType::Int16 => DataType::Int16,
ConstantType::Int32 => DataType::Int32,
ConstantType::Int64 => DataType::Int64,
ConstantType::Float64 => DataType::Float64,
ConstantType::Date => DataType::Date32,
ConstantType::IntervalMonthDateNano => DataType::Interval(IntervalUnit::MonthDayNano),
ConstantType::Decimal => DataType::Float64,
ConstantType::Utf8String => DataType::Utf8,
}
}
}

/// Typed wrapper around a predicate node that holds a constant value.
///
/// Nodes built through the constructors on `ConstantPred` have type
/// `PredicateType::Constant(..)`, no children, and the value stored as node data.
#[derive(Debug, Clone)]
pub struct ConstantPred(pub ArcPredicateNode);

impl ConstantPred {
pub fn new(value: Value) -> Self {
let typ = ConstantType::get_data_type_from_value(&value);
Self::new_with_type(value, typ)
}

pub fn new_with_type(value: Value, typ: ConstantType) -> Self {
ConstantPred(
PredicateNode {
typ: PredicateType::Constant(typ),
children: vec![],
data: Some(value),
}
.into(),
)
}

pub fn bool(value: bool) -> Self {
Self::new_with_type(Value::Bool(value), ConstantType::Bool)
}

pub fn string(value: impl AsRef<str>) -> Self {
Self::new_with_type(
Value::String(value.as_ref().into()),
ConstantType::Utf8String,
)
}

pub fn uint8(value: u8) -> Self {
Self::new_with_type(Value::UInt8(value), ConstantType::UInt8)
}

pub fn uint16(value: u16) -> Self {
Self::new_with_type(Value::UInt16(value), ConstantType::UInt16)
}

pub fn uint32(value: u32) -> Self {
Self::new_with_type(Value::UInt32(value), ConstantType::UInt32)
}

pub fn uint64(value: u64) -> Self {
Self::new_with_type(Value::UInt64(value), ConstantType::UInt64)
}

pub fn int8(value: i8) -> Self {
Self::new_with_type(Value::Int8(value), ConstantType::Int8)
}

pub fn int16(value: i16) -> Self {
Self::new_with_type(Value::Int16(value), ConstantType::Int16)
}

pub fn int32(value: i32) -> Self {
Self::new_with_type(Value::Int32(value), ConstantType::Int32)
}

pub fn int64(value: i64) -> Self {
Self::new_with_type(Value::Int64(value), ConstantType::Int64)
}

pub fn interval_month_day_nano(value: i128) -> Self {
Self::new_with_type(Value::Int128(value), ConstantType::IntervalMonthDateNano)
}

pub fn float64(value: f64) -> Self {
Self::new_with_type(
Value::Float(SerializableOrderedF64(value.into())),
ConstantType::Float64,
)
}

pub fn date(value: i64) -> Self {
Self::new_with_type(Value::Int64(value), ConstantType::Date)
}

pub fn decimal(value: f64) -> Self {
Self::new_with_type(
Value::Float(SerializableOrderedF64(value.into())),
ConstantType::Decimal,
)
}

pub fn serialized(value: Arc<[u8]>) -> Self {
Self::new_with_type(Value::Serialized(value), ConstantType::Binary)
}

/// Gets the constant value.
pub fn value(&self) -> Value {
self.0.data.clone().unwrap()
}

pub fn constant_type(&self) -> ConstantType {
if let PredicateType::Constant(typ) = self.0.typ {
typ
} else {
panic!("not a constant")
}
}
}

impl ReprPredicateNode for ConstantPred {
    /// Unwraps the typed predicate into its generic node.
    fn into_pred_node(self) -> ArcPredicateNode {
        self.0
    }

    /// Wraps `pred_node` if its type is `PredicateType::Constant(..)`; otherwise returns `None`.
    ///
    /// The parameter is named `pred_node` (not `rel_node`) to match the trait
    /// declaration and the other `ReprPredicateNode` implementations: the
    /// argument is a predicate node, not a relational node.
    fn from_pred_node(pred_node: ArcPredicateNode) -> Option<Self> {
        if !matches!(pred_node.typ, PredicateType::Constant(_)) {
            return None;
        }
        Some(Self(pred_node))
    }
}
40 changes: 40 additions & 0 deletions optd-cost-model/src/common/predicates/data_type_pred.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use arrow_schema::DataType;

use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode};

/// Typed wrapper around a predicate node that carries an arrow `DataType`.
///
/// Nodes built by `DataTypePred::new` keep the data type inside
/// `PredicateType::DataType(..)` and have neither children nor data.
#[derive(Debug, Clone)]
pub struct DataTypePred(pub ArcPredicateNode);

impl DataTypePred {
    /// Builds a data-type predicate holding `typ`.
    pub fn new(typ: DataType) -> Self {
        let node = PredicateNode {
            typ: PredicateType::DataType(typ),
            children: Vec::new(),
            data: None,
        };
        DataTypePred(node.into())
    }

    /// Returns a clone of the data type stored in the node's predicate type.
    ///
    /// # Panics
    ///
    /// Panics with "not a data type" if the wrapped node is not a
    /// `PredicateType::DataType`.
    pub fn data_type(&self) -> DataType {
        match &self.0.typ {
            PredicateType::DataType(data_type) => data_type.clone(),
            _ => panic!("not a data type"),
        }
    }
}

impl ReprPredicateNode for DataTypePred {
fn into_pred_node(self) -> ArcPredicateNode {
self.0
}

fn from_pred_node(pred_node: ArcPredicateNode) -> Option<Self> {
if !matches!(pred_node.typ, PredicateType::DataType(_)) {
return None;
}
Some(Self(pred_node))
}
}
3 changes: 3 additions & 0 deletions optd-cost-model/src/common/predicates/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
pub mod attr_ref_pred;
pub mod bin_op_pred;
pub mod cast_pred;
pub mod constant_pred;
pub mod data_type_pred;
pub mod func_pred;
pub mod log_op_pred;
pub mod sort_order_pred;
Expand Down
Loading

0 comments on commit a533ce6

Please sign in to comment.