Skip to content

Commit

Permalink
Add test for cardinality aware row converter on low card dict
Browse files Browse the repository at this point in the history
  • Loading branch information
JayjeetAtGithub committed Aug 31, 2023
1 parent f0c6853 commit 20f6dbb
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions datafusion/core/src/physical_plan/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,38 @@ mod tests {
use super::*;

// Generate a record batch whose dictionary field has high cardinality when `card` is "high", and low cardinality otherwise
fn generate_batch_with_high_card_dict_field() -> Result<RecordBatch, ArrowError> {
fn generate_batch_with_cardinality(card: String) -> Result<RecordBatch, ArrowError> {
let schema = SchemaRef::new(Schema::new(vec![
Field::new("a_dict", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), false),
Field::new("b_prim", DataType::Int32, false),
]));

let col_a: ArrayRef;
if card == "high" {
// building column `a_dict`
let mut values_vector: Vec<String> = Vec::new();
for _i in 1..=20 {
values_vector.push(String::from(Uuid::new_v4().to_string()));
}
let values = StringArray::from(values_vector);

// building column `a_dict`
let mut values_vector: Vec<String> = Vec::new();
for _i in 1..=20 {
values_vector.push(String::from(Uuid::new_v4().to_string()));
}
let values = StringArray::from(values_vector);

let mut keys_vector: Vec<i32> = Vec::new();
for _i in 1..=20 {
keys_vector.push(rand::thread_rng().gen_range(0..20));
let mut keys_vector: Vec<i32> = Vec::new();
for _i in 1..=20 {
keys_vector.push(rand::thread_rng().gen_range(0..20));
}
let keys = Int32Array::from(keys_vector);
col_a = Arc::new(DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap());
} else {
let values_vector = vec!["a", "b", "c"];
let values = StringArray::from(values_vector);

let mut keys_vector: Vec<i32> = Vec::new();
for _i in 1..=20 {
keys_vector.push(rand::thread_rng().gen_range(0..2));
}
let keys = Int32Array::from(keys_vector);
col_a = Arc::new(DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap());
}
let keys = Int32Array::from(keys_vector);
let col_a: ArrayRef = Arc::new(DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap());

// building column `b_prim`
let mut values: Vec<i32> = Vec::new();
Expand All @@ -110,12 +123,10 @@ mod tests {
// building the record batch
RecordBatch::try_from_iter(vec![("a_dict", col_a), ("b_prim", col_b)])
}

// fn generate_batch_with_low_card_dict_field() {}

#[tokio::test]
async fn test_cardinality_decision() {
let batch = generate_batch_with_high_card_dict_field().unwrap();
async fn test_with_high_card() {
let batch = generate_batch_with_cardinality(String::from("high")).unwrap();
let sort_fields = vec![
arrow::row::SortField::new_with_options(DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), SortOptions::default()),
arrow::row::SortField::new_with_options(DataType::Int32, SortOptions::default())
Expand All @@ -130,4 +141,22 @@ mod tests {
let converted_batch = converter.convert_rows(&rows).unwrap();
assert_eq!(converted_batch[0].data_type(), &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)));
}

#[tokio::test]
async fn test_with_low_card() {
    // Feed the same "low" cardinality batch through both converters and check
    // that, after a columns -> rows -> columns round trip, the first column
    // still reports the dictionary data type.
    let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let low_card_batch = generate_batch_with_cardinality(String::from("low")).unwrap();
    let fields = vec![
        arrow::row::SortField::new_with_options(dict_type.clone(), SortOptions::default()),
        arrow::row::SortField::new_with_options(DataType::Int32, SortOptions::default()),
    ];

    // Round trip through `CardinalityAwareRowConverter`.
    let mut aware_converter = CardinalityAwareRowConverter::new(fields.clone()).unwrap();
    let aware_rows = aware_converter.convert_columns(&low_card_batch.columns()).unwrap();
    let aware_columns = aware_converter.convert_rows(&aware_rows).unwrap();
    assert_eq!(aware_columns[0].data_type(), &dict_type);

    // Identical round trip through the plain `RowConverter` for comparison.
    let mut plain_converter = RowConverter::new(fields.clone()).unwrap();
    let plain_rows = plain_converter.convert_columns(&low_card_batch.columns()).unwrap();
    let plain_columns = plain_converter.convert_rows(&plain_rows).unwrap();
    assert_eq!(plain_columns[0].data_type(), &dict_type);
}
}

0 comments on commit 20f6dbb

Please sign in to comment.