Skip to content

Commit 24283fb

Browse files
committed
multi-col agg
1 parent f4e519f commit 24283fb

File tree

2 files changed

+209
-68
lines changed

2 files changed

+209
-68
lines changed

datafusion/physical-expr-common/src/binary_view_map.rs

+54-12
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
2424
use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
2525
use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
2626
use datafusion_common::hash_utils::create_hashes;
27-
use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
27+
use datafusion_common::utils::proxy::RawTableAllocExt;
2828
use std::fmt::Debug;
2929
use std::sync::Arc;
3030

@@ -207,6 +207,7 @@ where
207207
values,
208208
make_payload_fn,
209209
observe_payload_fn,
210+
None,
210211
)
211212
}
212213
OutputType::Utf8View => {
@@ -215,6 +216,43 @@ where
215216
values,
216217
make_payload_fn,
217218
observe_payload_fn,
219+
None,
220+
)
221+
}
222+
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
223+
};
224+
}
225+
226+
/// Similar to [`Self::insert_if_new`] but allows the caller to provide the
227+
/// hash values for the values in `values` instead of computing them
228+
pub fn insert_if_new_with_hash<MP, OP>(
229+
&mut self,
230+
values: &ArrayRef,
231+
make_payload_fn: MP,
232+
observe_payload_fn: OP,
233+
provided_hash: &Vec<u64>,
234+
) where
235+
MP: FnMut(Option<&[u8]>) -> V,
236+
OP: FnMut(V),
237+
{
238+
// Sanity check array type
239+
match self.output_type {
240+
OutputType::BinaryView => {
241+
assert!(matches!(values.data_type(), DataType::BinaryView));
242+
self.insert_if_new_inner::<MP, OP, BinaryViewType>(
243+
values,
244+
make_payload_fn,
245+
observe_payload_fn,
246+
Some(provided_hash),
247+
)
248+
}
249+
OutputType::Utf8View => {
250+
assert!(matches!(values.data_type(), DataType::Utf8View));
251+
self.insert_if_new_inner::<MP, OP, StringViewType>(
252+
values,
253+
make_payload_fn,
254+
observe_payload_fn,
255+
Some(provided_hash),
218256
)
219257
}
220258
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
@@ -234,19 +272,26 @@ where
234272
values: &ArrayRef,
235273
mut make_payload_fn: MP,
236274
mut observe_payload_fn: OP,
275+
provided_hash: Option<&Vec<u64>>,
237276
) where
238277
MP: FnMut(Option<&[u8]>) -> V,
239278
OP: FnMut(V),
240279
B: ByteViewType,
241280
{
242281
// step 1: compute hashes
243-
let batch_hashes = &mut self.hashes_buffer;
244-
batch_hashes.clear();
245-
batch_hashes.resize(values.len(), 0);
246-
create_hashes(&[values.clone()], &self.random_state, batch_hashes)
247-
// hash is supported for all types and create_hashes only
248-
// returns errors for unsupported types
249-
.unwrap();
282+
let batch_hashes = match provided_hash {
283+
Some(h) => h,
284+
None => {
285+
let batch_hashes = &mut self.hashes_buffer;
286+
batch_hashes.clear();
287+
batch_hashes.resize(values.len(), 0);
288+
create_hashes(&[values.clone()], &self.random_state, batch_hashes)
289+
// hash is supported for all types and create_hashes only
290+
// returns errors for unsupported types
291+
.unwrap();
292+
batch_hashes
293+
}
294+
};
250295

251296
// step 2: insert each value into the set, if not already present
252297
let values = values.as_byte_view::<B>();
@@ -353,9 +398,7 @@ where
353398
/// Return the total size, in bytes, of memory used to store the data in
354399
/// this set, not including `self`
355400
pub fn size(&self) -> usize {
356-
self.map_size
357-
+ self.builder.allocated_size()
358-
+ self.hashes_buffer.allocated_size()
401+
self.map_size + self.builder.allocated_size()
359402
}
360403
}
361404

@@ -369,7 +412,6 @@ where
369412
.field("map_size", &self.map_size)
370413
.field("view_builder", &self.builder)
371414
.field("random_state", &self.random_state)
372-
.field("hashes_buffer", &self.hashes_buffer)
373415
.finish()
374416
}
375417
}

0 commit comments

Comments
 (0)