@@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
24
24
use arrow:: array:: { Array , ArrayBuilder , ArrayRef , GenericByteViewBuilder } ;
25
25
use arrow:: datatypes:: { BinaryViewType , ByteViewType , DataType , StringViewType } ;
26
26
use datafusion_common:: hash_utils:: create_hashes;
27
- use datafusion_common:: utils:: proxy:: { RawTableAllocExt , VecAllocExt } ;
27
+ use datafusion_common:: utils:: proxy:: RawTableAllocExt ;
28
28
use std:: fmt:: Debug ;
29
29
use std:: sync:: Arc ;
30
30
@@ -207,6 +207,7 @@ where
207
207
values,
208
208
make_payload_fn,
209
209
observe_payload_fn,
210
+ None ,
210
211
)
211
212
}
212
213
OutputType :: Utf8View => {
@@ -215,6 +216,43 @@ where
215
216
values,
216
217
make_payload_fn,
217
218
observe_payload_fn,
219
+ None ,
220
+ )
221
+ }
222
+ _ => unreachable ! ( "Utf8/Binary should use `ArrowBytesSet`" ) ,
223
+ } ;
224
+ }
225
+
226
+ /// Similar to [`Self::insert_if_new`] but allows the caller to provide the
227
+ /// hash values for the values in `values` instead of computing them
228
+ pub fn insert_if_new_with_hash < MP , OP > (
229
+ & mut self ,
230
+ values : & ArrayRef ,
231
+ make_payload_fn : MP ,
232
+ observe_payload_fn : OP ,
233
+ provided_hash : & Vec < u64 > ,
234
+ ) where
235
+ MP : FnMut ( Option < & [ u8 ] > ) -> V ,
236
+ OP : FnMut ( V ) ,
237
+ {
238
+ // Sanity check array type
239
+ match self . output_type {
240
+ OutputType :: BinaryView => {
241
+ assert ! ( matches!( values. data_type( ) , DataType :: BinaryView ) ) ;
242
+ self . insert_if_new_inner :: < MP , OP , BinaryViewType > (
243
+ values,
244
+ make_payload_fn,
245
+ observe_payload_fn,
246
+ Some ( provided_hash) ,
247
+ )
248
+ }
249
+ OutputType :: Utf8View => {
250
+ assert ! ( matches!( values. data_type( ) , DataType :: Utf8View ) ) ;
251
+ self . insert_if_new_inner :: < MP , OP , StringViewType > (
252
+ values,
253
+ make_payload_fn,
254
+ observe_payload_fn,
255
+ Some ( provided_hash) ,
218
256
)
219
257
}
220
258
_ => unreachable ! ( "Utf8/Binary should use `ArrowBytesSet`" ) ,
@@ -234,19 +272,26 @@ where
234
272
values : & ArrayRef ,
235
273
mut make_payload_fn : MP ,
236
274
mut observe_payload_fn : OP ,
275
+ provided_hash : Option < & Vec < u64 > > ,
237
276
) where
238
277
MP : FnMut ( Option < & [ u8 ] > ) -> V ,
239
278
OP : FnMut ( V ) ,
240
279
B : ByteViewType ,
241
280
{
242
281
// step 1: compute hashes
243
- let batch_hashes = & mut self . hashes_buffer ;
244
- batch_hashes. clear ( ) ;
245
- batch_hashes. resize ( values. len ( ) , 0 ) ;
246
- create_hashes ( & [ values. clone ( ) ] , & self . random_state , batch_hashes)
247
- // hash is supported for all types and create_hashes only
248
- // returns errors for unsupported types
249
- . unwrap ( ) ;
282
+ let batch_hashes = match provided_hash {
283
+ Some ( h) => h,
284
+ None => {
285
+ let batch_hashes = & mut self . hashes_buffer ;
286
+ batch_hashes. clear ( ) ;
287
+ batch_hashes. resize ( values. len ( ) , 0 ) ;
288
+ create_hashes ( & [ values. clone ( ) ] , & self . random_state , batch_hashes)
289
+ // hash is supported for all types and create_hashes only
290
+ // returns errors for unsupported types
291
+ . unwrap ( ) ;
292
+ batch_hashes
293
+ }
294
+ } ;
250
295
251
296
// step 2: insert each value into the set, if not already present
252
297
let values = values. as_byte_view :: < B > ( ) ;
@@ -353,9 +398,7 @@ where
353
398
/// Return the total size, in bytes, of memory used to store the data in
354
399
/// this set, not including `self`
355
400
pub fn size ( & self ) -> usize {
356
- self . map_size
357
- + self . builder . allocated_size ( )
358
- + self . hashes_buffer . allocated_size ( )
401
+ self . map_size + self . builder . allocated_size ( )
359
402
}
360
403
}
361
404
@@ -369,7 +412,6 @@ where
369
412
. field ( "map_size" , & self . map_size )
370
413
. field ( "view_builder" , & self . builder )
371
414
. field ( "random_state" , & self . random_state )
372
- . field ( "hashes_buffer" , & self . hashes_buffer )
373
415
. finish ( )
374
416
}
375
417
}
0 commit comments