@@ -24,27 +24,156 @@ use std::fs::File;
24
24
use std:: io:: BufReader ;
25
25
use std:: path:: { Path , PathBuf } ;
26
26
use std:: ptr:: NonNull ;
27
+ use std:: sync:: Arc ;
27
28
28
29
use arrow:: array:: ArrayData ;
29
30
use arrow:: datatypes:: { Schema , SchemaRef } ;
30
31
use arrow:: ipc:: { reader:: StreamReader , writer:: StreamWriter } ;
31
32
use arrow:: record_batch:: RecordBatch ;
32
- use tokio:: sync:: mpsc:: Sender ;
33
-
34
- use datafusion_common:: { exec_datafusion_err, HashSet , Result } ;
35
-
36
- fn read_spill ( sender : Sender < Result < RecordBatch > > , path : & Path ) -> Result < ( ) > {
37
- let file = BufReader :: new ( File :: open ( path) ?) ;
38
- // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
39
- // with validated schemas and buffers. Skip redundant validation during read
40
- // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written.
41
- let reader = unsafe { StreamReader :: try_new ( file, None ) ?. with_skip_validation ( true ) } ;
42
- for batch in reader {
43
- sender
44
- . blocking_send ( batch. map_err ( Into :: into) )
45
- . map_err ( |e| exec_datafusion_err ! ( "{e}" ) ) ?;
33
+
34
+ use datafusion_common:: { exec_datafusion_err, DataFusionError , HashSet , Result } ;
35
+ use datafusion_common_runtime:: SpawnedTask ;
36
+ use datafusion_execution:: disk_manager:: RefCountedTempFile ;
37
+ use datafusion_execution:: RecordBatchStream ;
38
+ use futures:: { FutureExt as _, Stream } ;
39
+
40
+ /// Stream that reads spill files from disk where each batch is read in a spawned blocking task
41
+ /// It will read one batch at a time and will not do any buffering, to buffer data use [`crate::common::spawn_buffered`]
42
+ struct SpillReaderStream {
43
+ schema : SchemaRef ,
44
+ state : SpillReaderStreamState ,
45
+ }
46
+
47
+ /// When we poll for the next batch, we will get back both the batch and the reader,
48
+ /// so we can call `next` again.
49
+ type NextRecordBatchResult = Result < ( StreamReader < BufReader < File > > , Option < RecordBatch > ) > ;
50
+
51
+ enum SpillReaderStreamState {
52
+ /// Initial state: the stream was not initialized yet
53
+ /// and the file was not opened
54
+ Uninitialized ( RefCountedTempFile ) ,
55
+
56
+ /// A read is in progress in a spawned blocking task for which we hold the handle.
57
+ ReadInProgress ( SpawnedTask < NextRecordBatchResult > ) ,
58
+
59
+ /// A read has finished and we wait for being polled again in order to start reading the next batch.
60
+ Waiting ( StreamReader < BufReader < File > > ) ,
61
+
62
+ /// The stream has finished, successfully or not.
63
+ Done ,
64
+ }
65
+
66
+ impl SpillReaderStream {
67
+ fn new ( schema : SchemaRef , spill_file : RefCountedTempFile ) -> Self {
68
+ Self {
69
+ schema,
70
+ state : SpillReaderStreamState :: Uninitialized ( spill_file) ,
71
+ }
72
+ }
73
+
74
+ fn poll_next_inner (
75
+ & mut self ,
76
+ cx : & mut std:: task:: Context < ' _ > ,
77
+ ) -> std:: task:: Poll < Option < Result < RecordBatch > > > {
78
+ match & mut self . state {
79
+ SpillReaderStreamState :: Uninitialized ( _) => {
80
+ // Temporarily replace with `Done` to be able to pass the file to the task.
81
+ let SpillReaderStreamState :: Uninitialized ( spill_file) =
82
+ std:: mem:: replace ( & mut self . state , SpillReaderStreamState :: Done )
83
+ else {
84
+ unreachable ! ( )
85
+ } ;
86
+
87
+ let task = SpawnedTask :: spawn_blocking ( move || {
88
+ let file = BufReader :: new ( File :: open ( spill_file. path ( ) ) ?) ;
89
+ // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
90
+ // with validated schemas and buffers. Skip redundant validation during read
91
+ // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written.
92
+ let mut reader = unsafe {
93
+ StreamReader :: try_new ( file, None ) ?. with_skip_validation ( true )
94
+ } ;
95
+
96
+ let next_batch = reader. next ( ) . transpose ( ) ?;
97
+
98
+ Ok ( ( reader, next_batch) )
99
+ } ) ;
100
+
101
+ self . state = SpillReaderStreamState :: ReadInProgress ( task) ;
102
+
103
+ // Poll again immediately so the inner task is polled and the waker is
104
+ // registered.
105
+ self . poll_next_inner ( cx)
106
+ }
107
+
108
+ SpillReaderStreamState :: ReadInProgress ( task) => {
109
+ let result = futures:: ready!( task. poll_unpin( cx) )
110
+ . unwrap_or_else ( |err| Err ( DataFusionError :: External ( Box :: new ( err) ) ) ) ;
111
+
112
+ match result {
113
+ Ok ( ( reader, batch) ) => {
114
+ match batch {
115
+ Some ( batch) => {
116
+ self . state = SpillReaderStreamState :: Waiting ( reader) ;
117
+
118
+ std:: task:: Poll :: Ready ( Some ( Ok ( batch) ) )
119
+ }
120
+ None => {
121
+ // Stream is done
122
+ self . state = SpillReaderStreamState :: Done ;
123
+
124
+ std:: task:: Poll :: Ready ( None )
125
+ }
126
+ }
127
+ }
128
+ Err ( err) => {
129
+ self . state = SpillReaderStreamState :: Done ;
130
+
131
+ std:: task:: Poll :: Ready ( Some ( Err ( err) ) )
132
+ }
133
+ }
134
+ }
135
+
136
+ SpillReaderStreamState :: Waiting ( _) => {
137
+ // Temporarily replace with `Done` to be able to pass the file to the task.
138
+ let SpillReaderStreamState :: Waiting ( mut reader) =
139
+ std:: mem:: replace ( & mut self . state , SpillReaderStreamState :: Done )
140
+ else {
141
+ unreachable ! ( )
142
+ } ;
143
+
144
+ let task = SpawnedTask :: spawn_blocking ( move || {
145
+ let next_batch = reader. next ( ) . transpose ( ) ?;
146
+
147
+ Ok ( ( reader, next_batch) )
148
+ } ) ;
149
+
150
+ self . state = SpillReaderStreamState :: ReadInProgress ( task) ;
151
+
152
+ // Poll again immediately so the inner task is polled and the waker is
153
+ // registered.
154
+ self . poll_next_inner ( cx)
155
+ }
156
+
157
+ SpillReaderStreamState :: Done => std:: task:: Poll :: Ready ( None ) ,
158
+ }
159
+ }
160
+ }
161
+
162
+ impl Stream for SpillReaderStream {
163
+ type Item = Result < RecordBatch > ;
164
+
165
+ fn poll_next (
166
+ self : std:: pin:: Pin < & mut Self > ,
167
+ cx : & mut std:: task:: Context < ' _ > ,
168
+ ) -> std:: task:: Poll < Option < Self :: Item > > {
169
+ self . get_mut ( ) . poll_next_inner ( cx)
170
+ }
171
+ }
172
+
173
+ impl RecordBatchStream for SpillReaderStream {
174
+ fn schema ( & self ) -> SchemaRef {
175
+ Arc :: clone ( & self . schema )
46
176
}
47
- Ok ( ( ) )
48
177
}
49
178
50
179
/// Spill the `RecordBatch` to disk as smaller batches
@@ -205,6 +334,7 @@ mod tests {
205
334
use arrow:: record_batch:: RecordBatch ;
206
335
use datafusion_common:: Result ;
207
336
use datafusion_execution:: runtime_env:: RuntimeEnv ;
337
+ use futures:: StreamExt as _;
208
338
209
339
use std:: sync:: Arc ;
210
340
@@ -604,4 +734,42 @@ mod tests {
604
734
605
735
Ok ( ( ) )
606
736
}
737
+
738
+ #[ test]
739
+ fn test_reading_more_spills_than_tokio_blocking_threads ( ) -> Result < ( ) > {
740
+ tokio:: runtime:: Builder :: new_current_thread ( )
741
+ . enable_all ( )
742
+ . max_blocking_threads ( 1 )
743
+ . build ( )
744
+ . unwrap ( )
745
+ . block_on ( async {
746
+ let batch = build_table_i32 (
747
+ ( "a2" , & vec ! [ 0 , 1 , 2 ] ) ,
748
+ ( "b2" , & vec ! [ 3 , 4 , 5 ] ) ,
749
+ ( "c2" , & vec ! [ 4 , 5 , 6 ] ) ,
750
+ ) ;
751
+
752
+ let schema = batch. schema ( ) ;
753
+
754
+ // Construct SpillManager
755
+ let env = Arc :: new ( RuntimeEnv :: default ( ) ) ;
756
+ let metrics = SpillMetrics :: new ( & ExecutionPlanMetricsSet :: new ( ) , 0 ) ;
757
+ let spill_manager = SpillManager :: new ( env, metrics, Arc :: clone ( & schema) ) ;
758
+ let batches: [ _ ; 10 ] = std:: array:: from_fn ( |_| batch. clone ( ) ) ;
759
+
760
+ let spill_file_1 = spill_manager
761
+ . spill_record_batch_and_finish ( & batches, "Test1" ) ?
762
+ . unwrap ( ) ;
763
+ let spill_file_2 = spill_manager
764
+ . spill_record_batch_and_finish ( & batches, "Test2" ) ?
765
+ . unwrap ( ) ;
766
+
767
+ let mut stream_1 = spill_manager. read_spill_as_stream ( spill_file_1) ?;
768
+ let mut stream_2 = spill_manager. read_spill_as_stream ( spill_file_2) ?;
769
+ stream_1. next ( ) . await ;
770
+ stream_2. next ( ) . await ;
771
+
772
+ Ok ( ( ) )
773
+ } )
774
+ }
607
775
}
0 commit comments