@@ -23,8 +23,8 @@ use std::path::{Path, PathBuf};
 use std::ptr::NonNull;
 
 use arrow::array::ArrayData;
-use arrow::datatypes::SchemaRef;
-use arrow::ipc::reader::FileReader;
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
 use arrow::record_batch::RecordBatch;
 use log::debug;
 use tokio::sync::mpsc::Sender;
@@ -34,7 +34,6 @@ use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::human_readable_size;
 use datafusion_execution::SendableRecordBatchStream;
 
-use crate::common::IPCWriter;
 use crate::stream::RecordBatchReceiverStream;
 
 /// Read spilled batches from the disk
@@ -59,13 +58,13 @@ pub(crate) fn read_spill_as_stream(
 ///
 /// Returns total number of the rows spilled to disk.
 pub(crate) fn spill_record_batches(
-    batches: Vec<RecordBatch>,
+    batches: &[RecordBatch],
     path: PathBuf,
     schema: SchemaRef,
 ) -> Result<(usize, usize)> {
-    let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?;
+    let mut writer = IPCStreamWriter::new(path.as_ref(), schema.as_ref())?;
     for batch in batches {
-        writer.write(&batch)?;
+        writer.write(batch)?;
     }
     writer.finish()?;
     debug!(
@@ -79,7 +78,7 @@ pub(crate) fn spill_record_batches(
 
 fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {
     let file = BufReader::new(File::open(path)?);
-    let reader = FileReader::try_new(file, None)?;
+    let reader = StreamReader::try_new(file, None)?;
     for batch in reader {
         sender
             .blocking_send(batch.map_err(Into::into))
@@ -98,7 +97,7 @@ pub fn spill_record_batch_by_size(
 ) -> Result<()> {
     let mut offset = 0;
     let total_rows = batch.num_rows();
-    let mut writer = IPCWriter::new(&path, schema.as_ref())?;
+    let mut writer = IPCStreamWriter::new(&path, schema.as_ref())?;
 
     while offset < total_rows {
         let length = std::cmp::min(total_rows - offset, batch_size_rows);
@@ -130,7 +129,7 @@ pub fn spill_record_batch_by_size(
 /// {xxxxxxxxxxxxxxxxxxx} <--- buffer
 ///       ^    ^  ^    ^
 ///       |    |  |    |
-/// col1->{    }  |   |
+/// col1->{    }  |    |
 /// col2--------->{    }
 ///
 /// In the above case, `get_record_batch_memory_size` will return the size of
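To make the buffer-sharing behavior documented above concrete, here is a minimal standalone sketch (illustrative only, not part of this diff) that builds a `RecordBatch` whose two columns are zero-copy slices of one parent array. A naive per-column count such as `RecordBatch::get_array_memory_size` tallies the shared buffer once per column, which is exactly the double counting that `get_record_batch_memory_size` is documented to avoid:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() {
    // One parent array; both columns are zero-copy slices into the same buffer.
    let parent = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
    let col1: ArrayRef = Arc::new(parent.slice(0, 4));
    let col2: ArrayRef = Arc::new(parent.slice(4, 4));

    let schema = Arc::new(Schema::new(vec![
        Field::new("col1", DataType::Int32, false),
        Field::new("col2", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(schema, vec![col1, col2]).unwrap();

    // Counts the shared parent buffer once per column (twice in total);
    // the deduplicating count described above would count it only once.
    println!("per-column size: {}", batch.get_array_memory_size());
}
```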
@@ -179,17 +178,64 @@ fn count_array_data_memory_size(
     }
 }
 
+/// Write in Arrow IPC Stream format to a file.
+///
+/// Stream format is used for spill because it supports dictionary replacement, and the random
+/// access of IPC File format is not needed (IPC File format doesn't support dictionary replacement).
+struct IPCStreamWriter {
+    /// Inner writer
+    pub writer: StreamWriter<File>,
+    /// Batches written
+    pub num_batches: usize,
+    /// Rows written
+    pub num_rows: usize,
+    /// Bytes written
+    pub num_bytes: usize,
+}
+
+impl IPCStreamWriter {
+    /// Create new writer
+    pub fn new(path: &Path, schema: &Schema) -> Result<Self> {
+        let file = File::create(path).map_err(|e| {
+            exec_datafusion_err!("Failed to create partition file at {path:?}: {e:?}")
+        })?;
+        Ok(Self {
+            num_batches: 0,
+            num_rows: 0,
+            num_bytes: 0,
+            writer: StreamWriter::try_new(file, schema)?,
+        })
+    }
+
+    /// Write one single batch
+    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
+        self.writer.write(batch)?;
+        self.num_batches += 1;
+        self.num_rows += batch.num_rows();
+        let num_bytes: usize = batch.get_array_memory_size();
+        self.num_bytes += num_bytes;
+        Ok(())
+    }
+
+    /// Finish the writer
+    pub fn finish(&mut self) -> Result<()> {
+        self.writer.finish().map_err(Into::into)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::spill::{spill_record_batch_by_size, spill_record_batches};
     use crate::test::build_table_i32;
     use arrow::array::{Float64Array, Int32Array, ListArray};
+    use arrow::compute::cast;
     use arrow::datatypes::{DataType, Field, Int32Type, Schema};
     use arrow::record_batch::RecordBatch;
     use datafusion_common::Result;
     use datafusion_execution::disk_manager::DiskManagerConfig;
     use datafusion_execution::DiskManager;
+    use itertools::Itertools;
     use std::fs::File;
     use std::io::BufReader;
     use std::sync::Arc;
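As background for the doc comment on `IPCStreamWriter` above, the following standalone sketch (illustrative only, not part of this diff) round-trips two batches whose dictionaries differ through the IPC stream format. With `arrow::ipc::writer::FileWriter`'s default options, the second `write` should instead fail with a dictionary-replacement error, which is why the stream format is used for spills:

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow::array::{ArrayRef, DictionaryArray};
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let schema = Arc::new(Schema::new(vec![Field::new("tag", dict_type, true)]));

    // Two batches whose dictionary values differ: writing the second batch
    // replaces the first dictionary, which only the stream format allows.
    let a: DictionaryArray<Int32Type> = vec!["x", "y"].into_iter().collect();
    let b: DictionaryArray<Int32Type> = vec!["y", "z"].into_iter().collect();
    let batch1 =
        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(a) as ArrayRef])?;
    let batch2 =
        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(b) as ArrayRef])?;

    let mut buf = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch1)?;
        writer.write(&batch2)?; // dictionary replacement happens here
        writer.finish()?;
    }

    // Read the spilled bytes back with the matching stream reader.
    let reader = StreamReader::try_new(Cursor::new(buf), None)?;
    let batches = reader.collect::<Result<Vec<_>, _>>()?;
    assert_eq!(batches.len(), 2);
    Ok(())
}
```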
@@ -214,18 +260,85 @@ mod tests {
         let schema = batch1.schema();
         let num_rows = batch1.num_rows() + batch2.num_rows();
         let (spilled_rows, _) = spill_record_batches(
-            vec![batch1, batch2],
+            &[batch1, batch2],
             spill_file.path().into(),
             Arc::clone(&schema),
         )?;
         assert_eq!(spilled_rows, num_rows);
 
         let file = BufReader::new(File::open(spill_file.path())?);
-        let reader = FileReader::try_new(file, None)?;
+        let reader = StreamReader::try_new(file, None)?;
 
-        assert_eq!(reader.num_batches(), 2);
         assert_eq!(reader.schema(), schema);
 
+        let batches = reader.collect_vec();
+        assert!(batches.len() == 2);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> {
+        // See https://github.com/apache/datafusion/issues/4658
+
+        let batch1 = build_table_i32(
+            ("a2", &vec![0, 1, 2]),
+            ("b2", &vec![3, 4, 5]),
+            ("c2", &vec![4, 5, 6]),
+        );
+
+        let batch2 = build_table_i32(
+            ("a2", &vec![10, 11, 12]),
+            ("b2", &vec![13, 14, 15]),
+            ("c2", &vec![14, 15, 16]),
+        );
+
+        // Dictionary encode the arrays
+        let dict_type =
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
+        let dict_schema = Arc::new(Schema::new(vec![
+            Field::new("a2", dict_type.clone(), true),
+            Field::new("b2", dict_type.clone(), true),
+            Field::new("c2", dict_type.clone(), true),
+        ]));
+
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&dict_schema),
+            batch1
+                .columns()
+                .iter()
+                .map(|array| cast(array, &dict_type))
+                .collect::<Result<_, _>>()?,
+        )?;
+
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&dict_schema),
+            batch2
+                .columns()
+                .iter()
+                .map(|array| cast(array, &dict_type))
+                .collect::<Result<_, _>>()?,
+        )?;
+
+        let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
+
+        let spill_file = disk_manager.create_tmp_file("Test Spill")?;
+        let num_rows = batch1.num_rows() + batch2.num_rows();
+        let (spilled_rows, _) = spill_record_batches(
+            &[batch1, batch2],
+            spill_file.path().into(),
+            Arc::clone(&dict_schema),
+        )?;
+        assert_eq!(spilled_rows, num_rows);
+
+        let file = BufReader::new(File::open(spill_file.path())?);
+        let reader = StreamReader::try_new(file, None)?;
+
+        assert_eq!(reader.schema(), dict_schema);
+
+        let batches = reader.collect_vec();
+        assert!(batches.len() == 2);
+
         Ok(())
     }
 
@@ -249,11 +362,13 @@ mod tests {
         )?;
 
         let file = BufReader::new(File::open(spill_file.path())?);
-        let reader = FileReader::try_new(file, None)?;
+        let reader = StreamReader::try_new(file, None)?;
 
-        assert_eq!(reader.num_batches(), 4);
         assert_eq!(reader.schema(), schema);
 
+        let batches = reader.collect_vec();
+        assert!(batches.len() == 4);
+
         Ok(())
     }
 