Skip to content

Commit 1fe856b

Browse files
authored
predicate pruning: support cast and try_cast for more types (#15764)
* predicate pruning: support dictionaries * more types * clippy * add tests * add tests * simplify to dicts * revert most changes * just check for strings, more tests * more tests * remove unecessary now confusing clause
1 parent 11088b9 commit 1fe856b

File tree

1 file changed

+314
-14
lines changed

1 file changed

+314
-14
lines changed

datafusion/physical-optimizer/src/pruning.rs

+314-14
Original file line numberDiff line numberDiff line change
@@ -1210,23 +1210,35 @@ fn is_compare_op(op: Operator) -> bool {
12101210
)
12111211
}
12121212

1213+
fn is_string_type(data_type: &DataType) -> bool {
1214+
matches!(
1215+
data_type,
1216+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
1217+
)
1218+
}
1219+
12131220
// The pruning logic is based on the comparing the min/max bounds.
12141221
// Must make sure the two type has order.
12151222
// For example, casts from string to numbers is not correct.
12161223
// Because the "13" is less than "3" with UTF8 comparison order.
12171224
fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> {
1218-
// TODO: support other data type for prunable cast or try cast
1219-
if matches!(
1220-
from_type,
1221-
DataType::Int8
1222-
| DataType::Int16
1223-
| DataType::Int32
1224-
| DataType::Int64
1225-
| DataType::Decimal128(_, _)
1226-
) && matches!(
1227-
to_type,
1228-
DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _)
1229-
) {
1225+
// Dictionary casts are always supported as long as the value types are supported
1226+
let from_type = match from_type {
1227+
DataType::Dictionary(_, t) => {
1228+
return verify_support_type_for_prune(t.as_ref(), to_type)
1229+
}
1230+
_ => from_type,
1231+
};
1232+
let to_type = match to_type {
1233+
DataType::Dictionary(_, t) => {
1234+
return verify_support_type_for_prune(from_type, t.as_ref())
1235+
}
1236+
_ => to_type,
1237+
};
1238+
// If both types are strings or both are not strings (number, timestamp, etc)
1239+
// then we can compare them.
1240+
// PruningPredicate does not support casting of strings to numbers and such.
1241+
if is_string_type(from_type) == is_string_type(to_type) {
12301242
Ok(())
12311243
} else {
12321244
plan_err!(
@@ -1544,7 +1556,10 @@ fn build_predicate_expression(
15441556
Ok(builder) => builder,
15451557
// allow partial failure in predicate expression generation
15461558
// this can still produce a useful predicate when multiple conditions are joined using AND
1547-
Err(_) => return unhandled_hook.handle(expr),
1559+
Err(e) => {
1560+
dbg!(format!("Error building pruning expression: {e}"));
1561+
return unhandled_hook.handle(expr);
1562+
}
15481563
};
15491564

15501565
build_statistics_expr(&mut expr_builder)
@@ -3006,7 +3021,7 @@ mod tests {
30063021
}
30073022

30083023
#[test]
3009-
fn row_group_predicate_cast() -> Result<()> {
3024+
fn row_group_predicate_cast_int_int() -> Result<()> {
30103025
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
30113026
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)";
30123027

@@ -3043,6 +3058,291 @@ mod tests {
30433058
Ok(())
30443059
}
30453060

3061+
#[test]
3062+
fn row_group_predicate_cast_string_string() -> Result<()> {
3063+
let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]);
3064+
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)";
3065+
3066+
// test column on the left
3067+
let expr = cast(col("c1"), DataType::Utf8)
3068+
.eq(lit(ScalarValue::Utf8(Some("1".to_string()))));
3069+
let predicate_expr =
3070+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3071+
assert_eq!(predicate_expr.to_string(), expected_expr);
3072+
3073+
// test column on the right
3074+
let expr = lit(ScalarValue::Utf8(Some("1".to_string())))
3075+
.eq(cast(col("c1"), DataType::Utf8));
3076+
let predicate_expr =
3077+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3078+
assert_eq!(predicate_expr.to_string(), expected_expr);
3079+
3080+
Ok(())
3081+
}
3082+
3083+
#[test]
3084+
fn row_group_predicate_cast_string_int() -> Result<()> {
3085+
let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]);
3086+
let expected_expr = "true";
3087+
3088+
// test column on the left
3089+
let expr = cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1))));
3090+
let predicate_expr =
3091+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3092+
assert_eq!(predicate_expr.to_string(), expected_expr);
3093+
3094+
// test column on the right
3095+
let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32));
3096+
let predicate_expr =
3097+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3098+
assert_eq!(predicate_expr.to_string(), expected_expr);
3099+
3100+
Ok(())
3101+
}
3102+
3103+
#[test]
3104+
fn row_group_predicate_cast_int_string() -> Result<()> {
3105+
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
3106+
let expected_expr = "true";
3107+
3108+
// test column on the left
3109+
let expr = cast(col("c1"), DataType::Utf8)
3110+
.eq(lit(ScalarValue::Utf8(Some("1".to_string()))));
3111+
let predicate_expr =
3112+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3113+
assert_eq!(predicate_expr.to_string(), expected_expr);
3114+
3115+
// test column on the right
3116+
let expr = lit(ScalarValue::Utf8(Some("1".to_string())))
3117+
.eq(cast(col("c1"), DataType::Utf8));
3118+
let predicate_expr =
3119+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3120+
assert_eq!(predicate_expr.to_string(), expected_expr);
3121+
3122+
Ok(())
3123+
}
3124+
3125+
#[test]
3126+
fn row_group_predicate_date_date() -> Result<()> {
3127+
let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]);
3128+
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)";
3129+
3130+
// test column on the left
3131+
let expr =
3132+
cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123))));
3133+
let predicate_expr =
3134+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3135+
assert_eq!(predicate_expr.to_string(), expected_expr);
3136+
3137+
// test column on the right
3138+
let expr =
3139+
lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64));
3140+
let predicate_expr =
3141+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3142+
assert_eq!(predicate_expr.to_string(), expected_expr);
3143+
3144+
Ok(())
3145+
}
3146+
3147+
#[test]
3148+
fn row_group_predicate_dict_string_date() -> Result<()> {
3149+
// Test with Dictionary<UInt8, Utf8> for the literal
3150+
let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]);
3151+
let expected_expr = "true";
3152+
3153+
// test column on the left
3154+
let expr = cast(
3155+
col("c1"),
3156+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
3157+
)
3158+
.eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))));
3159+
let predicate_expr =
3160+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3161+
assert_eq!(predicate_expr.to_string(), expected_expr);
3162+
3163+
// test column on the right
3164+
let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast(
3165+
col("c1"),
3166+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
3167+
));
3168+
let predicate_expr =
3169+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3170+
assert_eq!(predicate_expr.to_string(), expected_expr);
3171+
3172+
Ok(())
3173+
}
3174+
3175+
#[test]
3176+
fn row_group_predicate_date_dict_string() -> Result<()> {
3177+
// Test with Dictionary<UInt8, Utf8> for the column
3178+
let schema = Schema::new(vec![Field::new(
3179+
"c1",
3180+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
3181+
false,
3182+
)]);
3183+
let expected_expr = "true";
3184+
3185+
// test column on the left
3186+
let expr =
3187+
cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123))));
3188+
let predicate_expr =
3189+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3190+
assert_eq!(predicate_expr.to_string(), expected_expr);
3191+
3192+
// test column on the right
3193+
let expr =
3194+
lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32));
3195+
let predicate_expr =
3196+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3197+
assert_eq!(predicate_expr.to_string(), expected_expr);
3198+
3199+
Ok(())
3200+
}
3201+
3202+
#[test]
3203+
fn row_group_predicate_dict_dict_same_value_type() -> Result<()> {
3204+
// Test with Dictionary types that have the same value type but different key types
3205+
let schema = Schema::new(vec![Field::new(
3206+
"c1",
3207+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
3208+
false,
3209+
)]);
3210+
3211+
// Direct comparison with no cast
3212+
let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string()))));
3213+
let predicate_expr =
3214+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3215+
let expected_expr =
3216+
"c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1";
3217+
assert_eq!(predicate_expr.to_string(), expected_expr);
3218+
3219+
// Test with column cast to a dictionary with different key type
3220+
let expr = cast(
3221+
col("c1"),
3222+
DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
3223+
)
3224+
.eq(lit(ScalarValue::Utf8(Some("test".to_string()))));
3225+
let predicate_expr =
3226+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3227+
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS Dictionary(UInt16, Utf8))";
3228+
assert_eq!(predicate_expr.to_string(), expected_expr);
3229+
3230+
Ok(())
3231+
}
3232+
3233+
#[test]
3234+
fn row_group_predicate_dict_dict_different_value_type() -> Result<()> {
3235+
// Test with Dictionary types that have different value types
3236+
let schema = Schema::new(vec![Field::new(
3237+
"c1",
3238+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)),
3239+
false,
3240+
)]);
3241+
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)";
3242+
3243+
// Test with literal of a different type
3244+
let expr =
3245+
cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(123))));
3246+
let predicate_expr =
3247+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3248+
assert_eq!(predicate_expr.to_string(), expected_expr);
3249+
3250+
Ok(())
3251+
}
3252+
3253+
#[test]
3254+
fn row_group_predicate_nested_dict() -> Result<()> {
3255+
// Test with nested Dictionary types
3256+
let schema = Schema::new(vec![Field::new(
3257+
"c1",
3258+
DataType::Dictionary(
3259+
Box::new(DataType::UInt8),
3260+
Box::new(DataType::Dictionary(
3261+
Box::new(DataType::UInt16),
3262+
Box::new(DataType::Utf8),
3263+
)),
3264+
),
3265+
false,
3266+
)]);
3267+
let expected_expr =
3268+
"c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1";
3269+
3270+
// Test with a simple literal
3271+
let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string()))));
3272+
let predicate_expr =
3273+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3274+
assert_eq!(predicate_expr.to_string(), expected_expr);
3275+
3276+
Ok(())
3277+
}
3278+
3279+
#[test]
3280+
fn row_group_predicate_dict_date_dict_date() -> Result<()> {
3281+
// Test with dictionary-wrapped date types for both sides
3282+
let schema = Schema::new(vec![Field::new(
3283+
"c1",
3284+
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32)),
3285+
false,
3286+
)]);
3287+
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))";
3288+
3289+
// Test with a cast to a different date type
3290+
let expr = cast(
3291+
col("c1"),
3292+
DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64)),
3293+
)
3294+
.eq(lit(ScalarValue::Date64(Some(123))));
3295+
let predicate_expr =
3296+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3297+
assert_eq!(predicate_expr.to_string(), expected_expr);
3298+
3299+
Ok(())
3300+
}
3301+
3302+
#[test]
3303+
fn row_group_predicate_date_string() -> Result<()> {
3304+
let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]);
3305+
let expected_expr = "true";
3306+
3307+
// test column on the left
3308+
let expr =
3309+
cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123))));
3310+
let predicate_expr =
3311+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3312+
assert_eq!(predicate_expr.to_string(), expected_expr);
3313+
3314+
// test column on the right
3315+
let expr =
3316+
lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32));
3317+
let predicate_expr =
3318+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3319+
assert_eq!(predicate_expr.to_string(), expected_expr);
3320+
3321+
Ok(())
3322+
}
3323+
3324+
#[test]
3325+
fn row_group_predicate_string_date() -> Result<()> {
3326+
let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]);
3327+
let expected_expr = "true";
3328+
3329+
// test column on the left
3330+
let expr = cast(col("c1"), DataType::Utf8)
3331+
.eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))));
3332+
let predicate_expr =
3333+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3334+
assert_eq!(predicate_expr.to_string(), expected_expr);
3335+
3336+
// test column on the right
3337+
let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))
3338+
.eq(cast(col("c1"), DataType::Utf8));
3339+
let predicate_expr =
3340+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
3341+
assert_eq!(predicate_expr.to_string(), expected_expr);
3342+
3343+
Ok(())
3344+
}
3345+
30463346
#[test]
30473347
fn row_group_predicate_cast_list() -> Result<()> {
30483348
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);

0 commit comments

Comments
 (0)