Skip to content

Commit

Permalink
Fix sort node deserialization from proto (#12626)
Browse files Browse the repository at this point in the history
  • Loading branch information
palaska committed Sep 27, 2024
1 parent 1b3608d commit 9b4f90a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
16 changes: 12 additions & 4 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,18 @@ impl LogicalPlanBuilder {
)
}

/// Apply a sort
pub fn sort(
self,
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
) -> Result<Self> {
self.sort_with_limit(sorts, None)
}

/// Apply a sort
pub fn sort_with_limit(
self,
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
fetch: Option<usize>,
) -> Result<Self> {
let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?;

Expand All @@ -597,7 +605,7 @@ impl LogicalPlanBuilder {
return Ok(Self::new(LogicalPlan::Sort(Sort {
expr: normalize_sorts(sorts, &self.plan)?,
input: self.plan,
fetch: None,
fetch,
})));
}

Expand All @@ -613,7 +621,7 @@ impl LogicalPlanBuilder {
let sort_plan = LogicalPlan::Sort(Sort {
expr: normalize_sorts(sorts, &plan)?,
input: Arc::new(plan),
fetch: None,
fetch,
});

Projection::try_new(new_expr, Arc::new(sort_plan))
Expand Down Expand Up @@ -1202,7 +1210,7 @@ impl LogicalPlanBuilder {

/// Unnest the given columns with the given [`UnnestOptions`]
/// if one column is a list type, it can be recursively and simultaneously
/// unnested into the desired recursion levels
/// unnested into the desired recursion levels
/// e.g select unnest(list_col,depth=1), unnest(list_col,depth=2)
pub fn unnest_columns_recursive_with_options(
self,
Expand Down
13 changes: 8 additions & 5 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,17 +490,20 @@ impl AsLogicalPlan for LogicalPlanNode {
into_logical_plan!(sort.input, ctx, extension_codec)?;
let sort_expr: Vec<SortExpr> =
from_proto::parse_sorts(&sort.expr, ctx, extension_codec)?;
LogicalPlanBuilder::from(input).sort(sort_expr)?.build()
let fetch: Option<usize> = sort.fetch.try_into().ok();
LogicalPlanBuilder::from(input)
.sort_with_limit(sort_expr, fetch)?
.build()
}
LogicalPlanType::Repartition(repartition) => {
use datafusion::logical_expr::Partitioning;
let input: LogicalPlan =
into_logical_plan!(repartition.input, ctx, extension_codec)?;
use protobuf::repartition_node::PartitionMethod;
let pb_partition_method = repartition.partition_method.as_ref().ok_or_else(|| {
DataFusionError::Internal(String::from(
"Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'",
))
internal_datafusion_err!(
"Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'"
)
})?;

let partitioning_scheme = match pb_partition_method {
Expand All @@ -526,7 +529,7 @@ impl AsLogicalPlan for LogicalPlanNode {
LogicalPlanType::CreateExternalTable(create_extern_table) => {
let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| {
DataFusionError::Internal(String::from(
"Protobuf deserialization error, CreateExternalTableNode was missing required field schema.",
"Protobuf deserialization error, CreateExternalTableNode was missing required field schema."
))
})?;

Expand Down
26 changes: 26 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,32 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_logical_plan_sort() -> Result<()> {
let ctx = SessionContext::new();

let schema = Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(15, 2), true),
]);

ctx.register_csv(
"t1",
"tests/testdata/test.csv",
CsvReadOptions::default().schema(&schema),
)
.await?;

let query = "SELECT a, b FROM t1 ORDER BY b LIMIT 5";
let plan = ctx.sql(query).await?.into_optimized_plan()?;

let bytes = logical_plan_to_bytes(&plan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
assert_eq!(format!("{plan}"), format!("{logical_round_trip}"));

Ok(())
}

#[tokio::test]
async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> {
let ctx = SessionContext::new();
Expand Down

0 comments on commit 9b4f90a

Please sign in to comment.