Skip to content

Commit

Permalink
Fix join on arrays of unhashable types
Browse files Browse the repository at this point in the history
Update can_hash to match currently supported hashes.
  • Loading branch information
findepi committed Nov 13, 2024
1 parent 1e69946 commit 4961ca6
Showing 1 changed file with 42 additions and 17 deletions.
59 changes: 42 additions & 17 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
};
use datafusion_expr_common::signature::{Signature, TypeSignature};

use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
Expand Down Expand Up @@ -958,7 +958,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr(

/// Can this data type be used in hash join equal conditions??
/// Data types here come from function 'equal_rows', if more data types are supported
/// in equal_rows(hash join), add those data types here to generate join logical plan.
/// in create_hashes, add those data types here to generate join logical plan.
pub fn can_hash(data_type: &DataType) -> bool {
match data_type {
DataType::Null => true,
Expand All @@ -971,31 +971,38 @@ pub fn can_hash(data_type: &DataType) -> bool {
DataType::UInt16 => true,
DataType::UInt32 => true,
DataType::UInt64 => true,
DataType::Float16 => true,
DataType::Float32 => true,
DataType::Float64 => true,
DataType::Timestamp(time_unit, _) => match time_unit {
TimeUnit::Second => true,
TimeUnit::Millisecond => true,
TimeUnit::Microsecond => true,
TimeUnit::Nanosecond => true,
},
DataType::Decimal128(_, _) => true,
DataType::Decimal256(_, _) => true,
DataType::Timestamp(_, _) => true,
DataType::Utf8 => true,
DataType::LargeUtf8 => true,
DataType::Utf8View => true,
DataType::Decimal128(_, _) => true,
DataType::Binary => true,
DataType::LargeBinary => true,
DataType::BinaryView => true,
DataType::Date32 => true,
DataType::Date64 => true,
DataType::Time32(_) => true,
DataType::Time64(_) => true,
DataType::Duration(_) => true,
DataType::Interval(_) => true,
DataType::FixedSizeBinary(_) => true,
DataType::Dictionary(key_type, value_type)
if *value_type.as_ref() == DataType::Utf8 =>
{
DataType::is_dictionary_key_type(key_type)
DataType::Dictionary(key_type, value_type) => {
DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
}
DataType::List(_) => true,
DataType::LargeList(_) => true,
DataType::FixedSizeList(_, _) => true,
DataType::List(value_type) => can_hash(value_type.data_type()),
DataType::LargeList(value_type) => can_hash(value_type.data_type()),
DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
_ => false,

DataType::ListView(_)
| DataType::LargeListView(_)
| DataType::Union(_, _)
| DataType::RunEndEncoded(_, _) => false,
}
}

Expand Down Expand Up @@ -1403,6 +1410,7 @@ mod tests {
test::function_stub::max_udaf, test::function_stub::min_udaf,
test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition,
};
use arrow::datatypes::{UnionFields, UnionMode};

#[test]
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
Expand Down Expand Up @@ -1805,4 +1813,21 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}

#[test]
fn test_can_hash() {
let union_fields: UnionFields = [
(0, Arc::new(Field::new("A", DataType::Int32, true))),
(1, Arc::new(Field::new("B", DataType::Float64, true))),
]
.into_iter()
.collect();

let union_type = DataType::Union(union_fields, UnionMode::Sparse);
assert!(!can_hash(&union_type));

let list_union_type =
DataType::List(Arc::new(Field::new("my_union", union_type, true)));
assert!(!can_hash(&list_union_type));
}
}

0 comments on commit 4961ca6

Please sign in to comment.