diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 1dd168f18167..b62b72ac6dd7 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -1210,23 +1210,35 @@ fn is_compare_op(op: Operator) -> bool { ) } +fn is_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. // Because the "13" is less than "3" with UTF8 comparison order. fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> { - // TODO: support other data type for prunable cast or try cast - if matches!( - from_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - ) && matches!( - to_type, - DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _) - ) { + // Dictionary casts are always supported as long as the value types are supported + let from_type = match from_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(t.as_ref(), to_type) + } + _ => from_type, + }; + let to_type = match to_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(from_type, t.as_ref()) + } + _ => to_type, + }; + // If both types are strings or both are not strings (number, timestamp, etc) + // then we can compare them. + // PruningPredicate does not support casting of strings to numbers and such. + if is_string_type(from_type) == is_string_type(to_type) { Ok(()) } else { plan_err!( @@ -1544,7 +1556,10 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => return unhandled_hook.handle(expr), + Err(e) => { + dbg!(format!("Error building pruning expression: {e}")); + return unhandled_hook.handle(expr); + } }; build_statistics_expr(&mut expr_builder) @@ -3006,7 +3021,7 @@ mod tests { } #[test] - fn row_group_predicate_cast() -> Result<()> { + fn row_group_predicate_cast_int_int() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; @@ -3043,6 +3058,291 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_cast_string_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_string_int() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_int_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_string_date() -> Result<()> { + // Test with Dictionary for the literal + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + )); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_dict_string() -> Result<()> { + // Test with Dictionary for the column + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_same_value_type() -> Result<()> { + // Test with Dictionary types that have the same value type but different key types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + + // Direct comparison with no cast + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + // Test with column cast to a dictionary with different key type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS Dictionary(UInt16, Utf8))"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_different_value_type() -> Result<()> { + // Test with Dictionary types that have different value types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)"; + + // Test with literal of a different type + let expr = + cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_nested_dict() -> Result<()> { + // Test with nested Dictionary types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + )), + ), + false, + )]); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + + // Test with a simple literal + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_date_dict_date() -> Result<()> { + // Test with dictionary-wrapped date types for both sides + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))"; + + // Test with a cast to a different date type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64)), + ) + .eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_string_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + #[test] fn row_group_predicate_cast_list() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);