Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3 from cmu-db/filter-operator
Browse files Browse the repository at this point in the history
Filter operator
  • Loading branch information
ktbooker authored Feb 25, 2024
2 parents 5834f26 + d6fbeaf commit 337101e
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 12 deletions.
Empty file added eggstrain/cargo
Empty file.
77 changes: 77 additions & 0 deletions eggstrain/src/execution/operators/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use super::{Operator, UnaryOperator};
use arrow::compute::filter_record_batch;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::common::cast::as_boolean_array;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::PhysicalExpr;
use datafusion_common::Result;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;

/// Filter operator state: a predicate expression plus the child plans it reports
/// through `Operator::children`.
pub struct Filter {
/// Physical expression evaluated against every incoming batch by `batch_filter`.
pub predicate: Arc<dyn PhysicalExpr>,
/// Child execution plans; handed back (cloned) by `Operator::children`.
pub children: Vec<Arc<dyn ExecutionPlan>>,
}

impl Filter {
    /// Creates a filter from its predicate expression and the child plans it sits on top of.
    pub fn new(predicate: Arc<dyn PhysicalExpr>, children: Vec<Arc<dyn ExecutionPlan>>) -> Self {
        Self { predicate, children }
    }

    /// Applies the predicate to `batch` and returns a batch holding only the selected rows.
    ///
    /// Modeled on DataFusion's own filter implementation:
    /// <https://docs.rs/datafusion-physical-plan/36.0.0/src/datafusion_physical_plan/filter.rs.html#307>
    ///
    /// # Errors
    ///
    /// Returns an error if the predicate fails to evaluate, if its result cannot be
    /// turned into an array of `batch.num_rows()` values, if that array is not a
    /// boolean array, or if filtering the record batch itself fails.
    pub fn batch_filter(&self, batch: RecordBatch) -> Result<RecordBatch> {
        // Evaluate the predicate, then materialize the result as one value per row.
        let evaluated = self.predicate.evaluate(&batch)?;
        let array = evaluated.into_array(batch.num_rows())?;

        // The predicate result must be boolean to serve as a selection mask.
        let mask = as_boolean_array(&array)?;

        // Apply the mask; `?` converts the Arrow error into the DataFusion error type.
        let filtered = filter_record_batch(&batch, mask)?;
        Ok(filtered)
    }
}

impl Operator for Filter {
    /// Returns a new vector holding a cloned `Arc` handle for each child plan.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        self.children.iter().cloned().collect()
    }
}

#[async_trait]
impl UnaryOperator for Filter {
    type In = RecordBatch;
    type Out = RecordBatch;

    /// Moves this filter into an `Arc` and exposes it as a type-erased unary operator.
    fn into_unary(self) -> Arc<dyn UnaryOperator<In = Self::In, Out = Self::Out>> {
        Arc::new(self)
    }

    /// Receives batches from `rx`, filters each one with [`Filter::batch_filter`], and
    /// forwards every non-empty result on `tx`. Returns once the input channel is closed.
    ///
    /// # Panics
    ///
    /// * if `batch_filter` returns an error,
    /// * if `tx.send` returns an error,
    /// * if this receiver lagged behind its input channel. The skipped batches cannot be
    ///   recovered and this method has no way to return an error, so continuing would
    ///   silently produce an incomplete result.
    async fn execute(
        &self,
        mut rx: broadcast::Receiver<Self::In>,
        tx: broadcast::Sender<Self::Out>,
    ) {
        loop {
            let batch = match rx.recv().await {
                Ok(batch) => batch,
                // Upstream has finished and every buffered batch has been consumed.
                Err(RecvError::Closed) => break,
                // Input batches were dropped before we could read them; fail loudly
                // and report how many, instead of a generic "not yet implemented".
                Err(RecvError::Lagged(skipped)) => panic!(
                    "Filter::execute() lagged behind its input; {} batch(es) were dropped",
                    skipped
                ),
            };

            let filtered_batch = self
                .batch_filter(batch)
                .expect("Filter::batch_filter() fails");

            // Batches with no surviving rows are not forwarded downstream.
            if filtered_batch.num_rows() > 0 {
                tx.send(filtered_batch).expect("tx.send() fails");
            }
        }
    }
}
1 change: 1 addition & 0 deletions eggstrain/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use datafusion::physical_plan::ExecutionPlan;
use std::sync::Arc;
use tokio::sync::broadcast::{Receiver, Sender};

pub mod filter;
pub mod project;

pub trait Operator {
Expand Down
77 changes: 65 additions & 12 deletions eggstrain/src/execution/query_dag.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::operators::filter::Filter;
use super::operators::project::Project;
use super::operators::UnaryOperator;
use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::{projection::ProjectionExec, ExecutionPlan, Partitioning};
use datafusion_common::{DataFusionError, Result};
use futures::stream::StreamExt;
Expand All @@ -15,17 +17,21 @@ const BATCH_SIZE: usize = 1024;
/// The operators this module wraps itself. Each live variant holds a shared,
/// type-erased unary operator that consumes and produces `RecordBatch`es.
/// The commented-out variants appear to be placeholders for operators that are
/// not wired up yet.
#[derive(Clone)]
pub(crate) enum EggstrainOperator {
/// Projection operator.
Project(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
// Filter(Arc<dyn UnaryNode>),
/// Filter operator.
Filter(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
// Sort(Arc<dyn UnaryNode>),

// Aggregate(Arc<dyn UnaryNode>),

// TableScan(Arc<dyn UnaryNode>),

// HashJoin(Arc<dyn BinaryNode>),
}

impl EggstrainOperator {
    /// Returns the child plans reported by the wrapped operator, whichever variant this is.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        // Both variants carry the same operator type, so one arm serves them all while
        // the match stays exhaustive if a new variant is added later.
        match self {
            Self::Project(op) | Self::Filter(op) => op.children(),
        }
    }
}
Expand All @@ -49,8 +55,16 @@ fn extract_df_node(plan: Arc<dyn ExecutionPlan>) -> Result<EggstrainOperator> {
let node = Project::new(child_schema, projection_plan).into_unary();

Ok(EggstrainOperator::Project(node))
// } else if id == TypeId::of::<FilterExec>() {
// todo!();
} else if id == TypeId::of::<FilterExec>() {
let filter_plan = root
.downcast_ref::<FilterExec>()
.expect("Unable to downcast_ref to FilterExec");

let node =
Filter::new(filter_plan.predicate().clone(), filter_plan.children()).into_unary();

Ok(EggstrainOperator::Filter(node))

// } else if id == TypeId::of::<HashJoinExec>() {
// todo!();
// } else if id == TypeId::of::<SortExec>() {
Expand Down Expand Up @@ -92,6 +106,7 @@ fn df_execute_node(plan: Arc<dyn ExecutionPlan>, tx: broadcast::Sender<RecordBat
}

pub fn build_query_dag(plan: Arc<dyn ExecutionPlan>) -> Result<broadcast::Receiver<RecordBatch>> {
// A tuple containing a plan node and a sender into that node
let mut queue = VecDeque::new();

// Final output is going to be sent to root_rx
Expand All @@ -103,23 +118,61 @@ pub fn build_query_dag(plan: Arc<dyn ExecutionPlan>) -> Result<broadcast::Receiv
queue.push_back((root, root_tx));

while let Some((node, tx)) = queue.pop_front() {
for child in node.children() {
let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);
let node = node.clone();

match node.children().len() {
0 => {
todo!();
}
1 => {
let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);
let child_plan = node.children()[0].clone();

if let Ok(child_node) = extract_df_node(child.clone()) {
match child_node.clone() {
EggstrainOperator::Project(project) => {
match node.clone() {
EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
let tx = tx.clone();
tokio::spawn(async move {
project.execute(child_rx, tx).await;
eggnode.execute(child_rx, tx).await;
});
}
};
queue.push_back((child_node, child_tx));
} else {
df_execute_node(child.clone(), tx.clone());

match extract_df_node(child_plan.clone()) {
Ok(val) => {
queue.push_back((val, child_tx));
}
Err(_) => {
df_execute_node(child_plan, child_tx);
}
}
}
2 => {
todo!();
}
_ => {
return Err(DataFusionError::NotImplemented(
"More than 2 children not implemented".to_string(),
));
}
}

// for child in node.children() {
// let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);

// if let Ok(child_node) = extract_df_node(child.clone()) {
// match child_node.clone() {
// EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
// let tx = tx.clone();
// tokio::spawn(async move {
// eggnode.execute(child_rx, tx).await;
// });
// }
// };
// queue.push_back((child_node, child_tx));
// } else {
// df_execute_node(child.clone(), tx.clone());
// }
// }
}

Ok(root_rx)
Expand Down
4 changes: 4 additions & 0 deletions eggstrain/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ async fn main() -> Result<()> {

let physical_plan = tpch.clone().create_physical_plan().await?;

println!("{:#?}", physical_plan.clone());

// let physical_plan = physical_plan.children()[0].clone();

let results = run(physical_plan).await;

results.into_iter().for_each(|batch| {
Expand Down
7 changes: 7 additions & 0 deletions queries/basic_filter.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Basic filter query: the total price of every order priced below 850.00.
SELECT o_totalprice
FROM orders
WHERE o_totalprice < 850.00;

0 comments on commit 337101e

Please sign in to comment.