Support (order by / sort) for DataFrameWriteOptions (#13874)
* Support (order by / sort) for DataFrameWriteOptions

* Fix fmt

* Fix import

* Add insert into example
zhuqi-lucas authored Dec 24, 2024
1 parent b4b267a commit 6cfd1cf
Showing 2 changed files with 282 additions and 4 deletions.
276 changes: 273 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
@@ -77,6 +77,9 @@ pub struct DataFrameWriteOptions {
/// Sets which columns should be used for hive-style partitioned writes by name.
/// Can be set to empty vec![] for non-partitioned writes.
partition_by: Vec<String>,
/// Sets which columns, by name, should be used to sort the output.
/// Can be set to empty vec![] for non-sorted writes.
sort_by: Vec<SortExpr>,
}

impl DataFrameWriteOptions {
@@ -86,6 +89,7 @@ impl DataFrameWriteOptions {
insert_op: InsertOp::Append,
single_file_output: false,
partition_by: vec![],
sort_by: vec![],
}
}

@@ -106,6 +110,12 @@ impl DataFrameWriteOptions {
self.partition_by = partition_by;
self
}

/// Sets the sort_by columns used to order the output before it is written
pub fn with_sort_by(mut self, sort_by: Vec<SortExpr>) -> Self {
self.sort_by = sort_by;
self
}
}

impl Default for DataFrameWriteOptions {
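A minimal usage sketch of the new builder method (not part of the commit; it assumes the usual use datafusion::prelude::*; imports, a CSV file input.csv, and an illustrative output path): write Parquet ordered by b descending with nulls first, breaking ties on a ascending with nulls last.

    let ctx = SessionContext::new();
    ctx.register_csv("t", "input.csv", CsvReadOptions::new()).await?;
    let df = ctx.table("t").await?;
    df.write_parquet(
        "/tmp/t_sorted.parquet",
        DataFrameWriteOptions::new()
            // sort(asc, nulls_first): b descending/nulls first, then a ascending/nulls last
            .with_sort_by(vec![col("b").sort(false, true), col("a").sort(true, false)]),
        None, // default Parquet writer options
    )
    .await?;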
@@ -1517,8 +1527,17 @@ impl DataFrame {
write_options: DataFrameWriteOptions,
) -> Result<Vec<RecordBatch>, DataFusionError> {
let arrow_schema = Schema::from(self.schema());

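// If the caller supplied sort_by expressions, wrap the input plan in a Sort
// node so rows arrive ordered at the INSERT; otherwise use the plan as-is.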
let plan = if write_options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(write_options.sort_by)?
.build()?
};

let plan = LogicalPlanBuilder::insert_into(
self.plan,
plan,
table_name.to_owned(),
&arrow_schema,
write_options.insert_op,
@@ -1577,8 +1596,16 @@ impl DataFrame {

let file_type = format_as_file_type(format);

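// Same guard as in write_table: only add a Sort node when sort_by is
// non-empty. The JSON writer below and parquet.rs repeat this pattern.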
let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(options.sort_by)?
.build()?
};

let plan = LogicalPlanBuilder::copy_to(
self.plan,
plan,
path.into(),
file_type,
HashMap::new(),
@@ -1638,8 +1665,16 @@ impl DataFrame {

let file_type = format_as_file_type(format);

let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(options.sort_by)?
.build()?
};

let plan = LogicalPlanBuilder::copy_to(
self.plan,
plan,
path.into(),
file_type,
Default::default(),
@@ -1940,6 +1975,7 @@ mod tests {
use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr};
use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name};

use crate::prelude::{CsvReadOptions, NdJsonReadOptions, ParquetReadOptions};
use arrow::array::Int32Array;
use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
@@ -1954,6 +1990,7 @@ mod tests {
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
use sqlparser::ast::NullTreatment;
use tempfile::TempDir;

// Get string representation of the plan
async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) {
@@ -4057,4 +4094,237 @@ mod tests {

Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_parquet_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let ctx = SessionContext::new();
let write_df = ctx.read_batch(RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 5, 7, 3, 2])),
Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
],
)?)?;

let test_path = tmp_dir.path().join("test.parquet");

write_df
.clone()
.write_parquet(
test_path.to_str().unwrap(),
DataFrameWriteOptions::new()
.with_sort_by(vec![col("a").sort(true, true)]),
None,
)
.await?;

let ctx = SessionContext::new();
ctx.register_parquet(
"data",
test_path.to_str().unwrap(),
ParquetReadOptions::default(),
)
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

let df_explain = ctx.sql("explain SELECT a FROM data").await?;
let explain_result = df_explain.collect().await?;

println!("explain_result {:?}", explain_result);

assert_batches_eq!(
&[
"+---+---+",
"| a | b |",
"+---+---+",
"| 1 | 2 |",
"| 2 | 6 |",
"| 3 | 5 |",
"| 5 | 3 |",
"| 7 | 4 |",
"+---+---+",
],
&results
);
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_csv_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let ctx = SessionContext::new();
let write_df = ctx.read_batch(RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 5, 7, 3, 2])),
Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
],
)?)?;

let test_path = tmp_dir.path().join("test.csv");

write_df
.clone()
.write_csv(
test_path.to_str().unwrap(),
DataFrameWriteOptions::new()
.with_sort_by(vec![col("a").sort(true, true)]),
None,
)
.await?;

let ctx = SessionContext::new();
ctx.register_csv(
"data",
test_path.to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+---+---+",
"| a | b |",
"+---+---+",
"| 1 | 2 |",
"| 2 | 6 |",
"| 3 | 5 |",
"| 5 | 3 |",
"| 7 | 4 |",
"+---+---+",
],
&results
);
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_json_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let ctx = SessionContext::new();
let write_df = ctx.read_batch(RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 5, 7, 3, 2])),
Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
],
)?)?;

let test_path = tmp_dir.path().join("test.json");

write_df
.clone()
.write_json(
test_path.to_str().unwrap(),
DataFrameWriteOptions::new()
.with_sort_by(vec![col("a").sort(true, true)]),
None,
)
.await?;

let ctx = SessionContext::new();
ctx.register_json(
"data",
test_path.to_str().unwrap(),
NdJsonReadOptions::default().schema(&schema),
)
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+---+---+",
"| a | b |",
"+---+---+",
"| 1 | 2 |",
"| 2 | 6 |",
"| 3 | 5 |",
"| 5 | 3 |",
"| 7 | 4 |",
"+---+---+",
],
&results
);
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_table_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let ctx = SessionContext::new();
let location = tmp_dir.path().join("test_table/");

let mut write_df = ctx
.sql("values ('z'), ('x'), ('a'), ('b'), ('c')")
.await
.unwrap();

// Ensure the column names and types match the target table
write_df = write_df
.with_column_renamed("column1", "tablecol1")
.unwrap();
let sql_str =
"create external table data(tablecol1 varchar) stored as parquet location '"
.to_owned()
+ location.to_str().unwrap()
+ "'";

ctx.sql(sql_str.as_str()).await?.collect().await?;

// This is equivalent to INSERT INTO data.
write_df
.clone()
.write_table(
"data",
DataFrameWriteOptions::new()
.with_sort_by(vec![col("tablecol1").sort(true, true)]),
)
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+-----------+",
"| tablecol1 |",
"+-----------+",
"| a |",
"| b |",
"| c |",
"| x |",
"| z |",
"+-----------+",
],
&results
);
Ok(())
}
}
10 changes: 9 additions & 1 deletion datafusion/core/src/dataframe/parquet.rs
@@ -74,8 +74,16 @@ impl DataFrame {

let file_type = format_as_file_type(format);

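// Same sort-before-COPY guard as the writers in dataframe/mod.rs.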
let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(options.sort_by)?
.build()?
};

let plan = LogicalPlanBuilder::copy_to(
self.plan,
plan,
path.into(),
file_type,
Default::default(),
