diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs
index b83641acf2ce..a8d4a3ddf3d1 100644
--- a/datafusion/physical-plan/src/common.rs
+++ b/datafusion/physical-plan/src/common.rs
@@ -180,7 +180,7 @@ pub fn compute_record_batch_statistics(
     }
 }
 
-/// Write in Arrow IPC format.
+/// Write in Arrow IPC File format.
 pub struct IPCWriter {
     /// Path
     pub path: PathBuf,
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 6c933ca21807..9b008f5242c4 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -59,7 +59,7 @@ use arrow::compute::{
 };
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use arrow::error::ArrowError;
-use arrow::ipc::reader::FileReader;
+use arrow::ipc::reader::StreamReader;
 use datafusion_common::{
     exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet,
     JoinSide, JoinType, Result,
@@ -1394,7 +1394,7 @@ impl SortMergeJoinStream {
 
             if let Some(batch) = buffered_batch.batch {
                 spill_record_batches(
-                    vec![batch],
+                    &[batch],
                     spill_file.path().into(),
                     Arc::clone(&self.buffered_schema),
                 )?;
@@ -2270,7 +2270,7 @@ fn fetch_right_columns_from_batch_by_idxs(
             Vec::with_capacity(buffered_indices.len());
 
         let file = BufReader::new(File::open(spill_file.path())?);
-        let reader = FileReader::try_new(file, None)?;
+        let reader = StreamReader::try_new(file, None)?;
         for batch in reader {
             batch?.columns().iter().for_each(|column| {
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index d84068527a64..751496c70808 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -407,7 +407,7 @@ impl ExternalSorter {
         let spill_file = self.runtime.disk_manager.create_tmp_file("Sorting")?;
         let batches = std::mem::take(&mut self.in_mem_batches);
         let (spilled_rows, spilled_bytes) = spill_record_batches(
-            batches,
+            &batches,
             spill_file.path().into(),
             Arc::clone(&self.schema),
         )?;
diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs
index b45353ae13f0..fa1b8a91cec7 100644
--- a/datafusion/physical-plan/src/spill.rs
+++ b/datafusion/physical-plan/src/spill.rs
@@ -23,8 +23,8 @@ use std::path::{Path, PathBuf};
 use std::ptr::NonNull;
 
 use arrow::array::ArrayData;
-use arrow::datatypes::SchemaRef;
-use arrow::ipc::reader::FileReader;
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
 use arrow::record_batch::RecordBatch;
 use log::debug;
 use tokio::sync::mpsc::Sender;
@@ -34,7 +34,6 @@ use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::human_readable_size;
 use datafusion_execution::SendableRecordBatchStream;
 
-use crate::common::IPCWriter;
 use crate::stream::RecordBatchReceiverStream;
 
 /// Read spilled batches from the disk
@@ -59,13 +58,13 @@
 ///
 /// Returns total number of the rows spilled to disk.
 pub(crate) fn spill_record_batches(
-    batches: Vec<RecordBatch>,
+    batches: &[RecordBatch],
     path: PathBuf,
     schema: SchemaRef,
 ) -> Result<(usize, usize)> {
-    let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?;
+    let mut writer = IPCStreamWriter::new(path.as_ref(), schema.as_ref())?;
     for batch in batches {
-        writer.write(&batch)?;
+        writer.write(batch)?;
     }
     writer.finish()?;
     debug!(
@@ -79,7 +78,7 @@
 
 fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {
     let file = BufReader::new(File::open(path)?);
-    let reader = FileReader::try_new(file, None)?;
+    let reader = StreamReader::try_new(file, None)?;
     for batch in reader {
         sender
             .blocking_send(batch.map_err(Into::into))
@@ -98,7 +97,7 @@ pub fn spill_record_batch_by_size(
 ) -> Result<()> {
     let mut offset = 0;
     let total_rows = batch.num_rows();
-    let mut writer = IPCWriter::new(&path, schema.as_ref())?;
+    let mut writer = IPCStreamWriter::new(&path, schema.as_ref())?;
 
     while offset < total_rows {
         let length = std::cmp::min(total_rows - offset, batch_size_rows);
@@ -130,7 +129,7 @@ pub fn spill_record_batch_by_size(
 /// {xxxxxxxxxxxxxxxxxxx} <--- buffer
 ///       ^    ^  ^    ^
 ///       |    |  |    |
-/// col1->{    }  |    |   
+/// col1->{    }  |    |
 /// col2--------->{    }
 ///
 /// In the above case, `get_record_batch_memory_size` will return the size of
@@ -179,17 +178,64 @@ fn count_array_data_memory_size(
     }
 }
 
+/// Write in Arrow IPC Stream format to a file.
+///
+/// Stream format is used for spill because it supports dictionary replacement, and the random
+/// access of IPC File format is not needed (IPC File format doesn't support dictionary replacement).
+struct IPCStreamWriter {
+    /// Inner writer
+    pub writer: StreamWriter<File>,
+    /// Batches written
+    pub num_batches: usize,
+    /// Rows written
+    pub num_rows: usize,
+    /// Bytes written
+    pub num_bytes: usize,
+}
+
+impl IPCStreamWriter {
+    /// Create new writer
+    pub fn new(path: &Path, schema: &Schema) -> Result<Self> {
+        let file = File::create(path).map_err(|e| {
+            exec_datafusion_err!("Failed to create partition file at {path:?}: {e:?}")
+        })?;
+        Ok(Self {
+            num_batches: 0,
+            num_rows: 0,
+            num_bytes: 0,
+            writer: StreamWriter::try_new(file, schema)?,
+        })
+    }
+
+    /// Write one single batch
+    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
+        self.writer.write(batch)?;
+        self.num_batches += 1;
+        self.num_rows += batch.num_rows();
+        let num_bytes: usize = batch.get_array_memory_size();
+        self.num_bytes += num_bytes;
+        Ok(())
+    }
+
+    /// Finish the writer
+    pub fn finish(&mut self) -> Result<()> {
+        self.writer.finish().map_err(Into::into)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::spill::{spill_record_batch_by_size, spill_record_batches};
     use crate::test::build_table_i32;
     use arrow::array::{Float64Array, Int32Array, ListArray};
+    use arrow::compute::cast;
     use arrow::datatypes::{DataType, Field, Int32Type, Schema};
     use arrow::record_batch::RecordBatch;
     use datafusion_common::Result;
     use datafusion_execution::disk_manager::DiskManagerConfig;
     use datafusion_execution::DiskManager;
+    use itertools::Itertools;
     use std::fs::File;
     use std::io::BufReader;
     use std::sync::Arc;
@@ -214,18 +260,85 @@ mod tests {
         let schema = batch1.schema();
         let num_rows = batch1.num_rows() + batch2.num_rows();
         let (spilled_rows, _) = spill_record_batches(
-            vec![batch1, batch2],
+            &[batch1, batch2],
             spill_file.path().into(),
             Arc::clone(&schema),
         )?;
         assert_eq!(spilled_rows, num_rows);
 
         let file = BufReader::new(File::open(spill_file.path())?);
-        let reader = FileReader::try_new(file, None)?;
+        let reader = StreamReader::try_new(file, None)?;
 
-        assert_eq!(reader.num_batches(), 2);
         assert_eq!(reader.schema(), schema);
 
+        let batches = reader.collect_vec();
+        assert!(batches.len() == 2);
+
         Ok(())
     }
 
+    #[test]
+    fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> {
+        // See https://github.com/apache/datafusion/issues/4658
+
+        let batch1 = build_table_i32(
+            ("a2", &vec![0, 1, 2]),
+            ("b2", &vec![3, 4, 5]),
+            ("c2", &vec![4, 5, 6]),
+        );
+
+        let batch2 = build_table_i32(
+            ("a2", &vec![10, 11, 12]),
+            ("b2", &vec![13, 14, 15]),
+            ("c2", &vec![14, 15, 16]),
+        );
+
+        // Dictionary encode the arrays
+        let dict_type =
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
+        let dict_schema = Arc::new(Schema::new(vec![
+            Field::new("a2", dict_type.clone(), true),
+            Field::new("b2", dict_type.clone(), true),
+            Field::new("c2", dict_type.clone(), true),
+        ]));
+
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&dict_schema),
+            batch1
+                .columns()
+                .iter()
+                .map(|array| cast(array, &dict_type))
+                .collect::<Result<_, _>>()?,
+        )?;
+
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&dict_schema),
+            batch2
+                .columns()
+                .iter()
+                .map(|array| cast(array, &dict_type))
+                .collect::<Result<_, _>>()?,
+        )?;
+
+        let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
+
+        let spill_file = disk_manager.create_tmp_file("Test Spill")?;
+        let num_rows = batch1.num_rows() + batch2.num_rows();
+        let (spilled_rows, _) = spill_record_batches(
+            &[batch1, batch2],
+            spill_file.path().into(),
+            Arc::clone(&dict_schema),
+        )?;
+        assert_eq!(spilled_rows, num_rows);
+
+        let file = BufReader::new(File::open(spill_file.path())?);
+        let reader = StreamReader::try_new(file, None)?;
+
+        assert_eq!(reader.schema(), dict_schema);
+
+        let batches = reader.collect_vec();
+        assert!(batches.len() == 2);
+
+        Ok(())
+    }
+
@@ -249,11 +362,13 @@
         )?;
 
         let file = BufReader::new(File::open(spill_file.path())?);
-        let reader = FileReader::try_new(file, None)?;
+        let reader = StreamReader::try_new(file, None)?;
 
-        assert_eq!(reader.num_batches(), 4);
         assert_eq!(reader.schema(), schema);
 
+        let batches = reader.collect_vec();
+        assert!(batches.len() == 4);
+
         Ok(())
     }
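
Note for reviewers: the dictionary-replacement behavior that motivates switching `IPCWriter`/`FileReader` to `StreamWriter`/`StreamReader` can be reproduced outside this patch. Below is a minimal, self-contained sketch (not part of the diff) using only the `arrow` crate APIs touched here; the column name, values, and in-memory `Vec<u8>` sink are illustrative assumptions rather than anything from the PR. Writing the same two batches through the IPC `FileWriter` should instead fail on the second `write`, since the File format cannot represent a replaced dictionary; that is the failure mode behind issue #4658 referenced in the new test.

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow::array::{ArrayRef, DictionaryArray};
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    // One dictionary-encoded column; the two batches below carry different
    // dictionaries ("a"/"b" vs "c"/"d"), i.e. the dictionary is replaced
    // between batches.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "c0",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        true,
    )]));
    let dict1: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
    let dict2: DictionaryArray<Int32Type> = vec!["c", "d"].into_iter().collect();
    let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(dict1) as ArrayRef])?;
    let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(dict2) as ArrayRef])?;

    // Spill-style write in IPC Stream format (to an in-memory buffer here;
    // the new `IPCStreamWriter` writes to a temp file instead).
    let mut writer = StreamWriter::try_new(Vec::<u8>::new(), &schema)?;
    writer.write(&batch1)?;
    writer.write(&batch2)?; // accepted: stream format allows dictionary replacement
    writer.finish()?;
    let buf = writer.into_inner()?;

    // Read the spilled bytes back; the stream reader tracks the replaced dictionary.
    let reader = StreamReader::try_new(Cursor::new(buf), None)?;
    let batches = reader.collect::<Result<Vec<_>, _>>()?;
    assert_eq!(batches.len(), 2);
    assert_eq!(batches[0], batch1);
    assert_eq!(batches[1], batch2);
    Ok(())
}
```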