Skip to content

predicate pruning: support cast and try_cast for more types #15764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
328 changes: 314 additions & 14 deletions datafusion/physical-optimizer/src/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,23 +1210,35 @@ fn is_compare_op(op: Operator) -> bool {
)
}

/// Returns `true` when `data_type` is one of the UTF-8 string types
/// (`Utf8`, `LargeUtf8` or `Utf8View`) and `false` for every other type.
fn is_string_type(data_type: &DataType) -> bool {
    [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View].contains(data_type)
}

// The pruning logic is based on comparing the min/max bounds,
// so we must make sure the two types share a compatible ordering.
// For example, casts from strings to numbers are not correct,
// because "13" is less than "3" under UTF8 comparison order.
fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> {
// TODO: support other data type for prunable cast or try cast
if matches!(
from_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Decimal128(_, _)
) && matches!(
to_type,
DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _)
) {
// Dictionary casts are always supported as long as the value types are supported
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember why we even have this code -- I think it may predate unwrapping casts in the logical planning phase, so by the time this code sees physical expressions it shouldn't have to add casts at all 🤔

Copy link

@etseidl etseidl Apr 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the big issue is when comparing string to numeric (both sides will be coerced to string IIUC). Rather than a whitelist here, maybe instead disallow casts we know will fail, so this could maybe just return Err if the LHS is numeric or temporal and the RHS is string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @etseidl, I'll implement that instead. I think we do need to keep the dict part in here.

Tangential, but do any of you know where we have functionality to simplify/adapt a PhysicalExpr given a schema, e.g. to remove unnecessary casts or add casts where necessary?

let from_type = match from_type {
DataType::Dictionary(_, t) => {
return verify_support_type_for_prune(t.as_ref(), to_type)
}
_ => from_type,
};
let to_type = match to_type {
DataType::Dictionary(_, t) => {
return verify_support_type_for_prune(from_type, t.as_ref())
}
_ => to_type,
};
// If both types are strings or both are not strings (number, timestamp, etc)
// then we can compare them.
// PruningPredicate does not support casting of strings to numbers and such.
if is_string_type(from_type) == is_string_type(to_type) {
Ok(())
} else {
plan_err!(
Expand Down Expand Up @@ -1544,7 +1556,10 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => return unhandled_hook.handle(expr),
Err(e) => {
dbg!(format!("Error building pruning expression: {e}"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend changing this to `debug!` (debug logging rather than printing to stderr on failure)

return unhandled_hook.handle(expr);
}
};

build_statistics_expr(&mut expr_builder)
Expand Down Expand Up @@ -3006,7 +3021,7 @@ mod tests {
}

#[test]
fn row_group_predicate_cast() -> Result<()> {
fn row_group_predicate_cast_int_int() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)";

Expand Down Expand Up @@ -3043,6 +3058,291 @@ mod tests {
Ok(())
}

#[test]
fn row_group_predicate_cast_string_string() -> Result<()> {
    // A Utf8View column cast to Utf8 and compared with a Utf8 literal: the expected
    // predicate compares the cast min/max statistics against the literal.
    let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]);
    let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)";

    let cast_column = || cast(col("c1"), DataType::Utf8);
    let literal = || lit(ScalarValue::Utf8(Some("1".to_string())));

    // Column on the left first, then column on the right; both orientations must
    // produce the same predicate.
    for expr in [cast_column().eq(literal()), literal().eq(cast_column())] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_cast_string_int() -> Result<()> {
    // A Utf8View column cast to Int32 mixes a string type with a non-string type;
    // the expected rewritten predicate is just the literal `true`.
    let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]);
    let expected_expr = "true";

    let column_on_left =
        cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1))));
    let column_on_right =
        lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32));

    for expr in [column_on_left, column_on_right] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_cast_int_string() -> Result<()> {
    // An Int32 column cast to Utf8 mixes a non-string type with a string type;
    // the expected rewritten predicate is just the literal `true`.
    let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
    let expected_expr = "true";

    let cast_column = || cast(col("c1"), DataType::Utf8);
    let literal = || lit(ScalarValue::Utf8(Some("1".to_string())));

    // Exercise the column on the left-hand side, then on the right-hand side.
    let orientations = [cast_column().eq(literal()), literal().eq(cast_column())];
    for expr in orientations {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_date_date() -> Result<()> {
    // A Date32 column cast to Date64 and compared with a Date64 literal: the
    // expected predicate casts the min/max statistics to Date64 as well.
    let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]);
    let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)";

    let column_on_left =
        cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123))));
    let column_on_right =
        lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64));

    for expr in [column_on_left, column_on_right] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_dict_string_date() -> Result<()> {
    // A Date32 column cast to Dictionary<UInt8, Utf8> and compared with a Utf8
    // literal; the expected rewritten predicate is just the literal `true`.
    let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]);
    let expected_expr = "true";

    let dict_utf8 =
        || DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8));
    let literal = || lit(ScalarValue::Utf8(Some("2024-01-01".to_string())));

    let column_on_left = cast(col("c1"), dict_utf8()).eq(literal());
    let column_on_right = literal().eq(cast(col("c1"), dict_utf8()));

    for expr in [column_on_left, column_on_right] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_date_dict_string() -> Result<()> {
    // The column itself is Dictionary<UInt8, Utf8>; it is cast to Date32 and
    // compared with a Date32 literal. The expected predicate is the literal `true`.
    let column_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8));
    let schema = Schema::new(vec![Field::new("c1", column_type, false)]);
    let expected_expr = "true";

    let cast_column = || cast(col("c1"), DataType::Date32);
    let literal = || lit(ScalarValue::Date32(Some(123)));

    // Column on the left, then column on the right.
    for expr in [cast_column().eq(literal()), literal().eq(cast_column())] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_dict_dict_same_value_type() -> Result<()> {
    // The column is Dictionary<UInt8, Utf8>. It is compared with a Utf8 literal
    // both directly and after a cast to a dictionary that keeps the Utf8 value
    // type but uses a different (UInt16) key type.
    let schema = Schema::new(vec![Field::new(
        "c1",
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
        false,
    )]);

    let literal = || lit(ScalarValue::Utf8(Some("test".to_string())));
    let wider_key_dict =
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));

    // Each case pairs an input expression with the predicate it must rewrite to.
    let cases = [
        // Direct comparison with no cast
        (
            col("c1").eq(literal()),
            "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1",
        ),
        // Column cast to a dictionary with a different key type
        (
            cast(col("c1"), wider_key_dict).eq(literal()),
            "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS Dictionary(UInt16, Utf8))",
        ),
    ];

    for (expr, expected_expr) in cases {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_dict_dict_different_value_type() -> Result<()> {
    // The column is Dictionary<UInt8, Int32>; it is cast to Int64 and compared
    // with an Int64 literal. The expected predicate casts min/max to Int64.
    let column_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![Field::new("c1", column_type, false)]);
    let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)";

    let literal = lit(ScalarValue::Int64(Some(123)));
    let expr = cast(col("c1"), DataType::Int64).eq(literal);

    let predicate_expr =
        test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
    assert_eq!(predicate_expr.to_string(), expected_expr);

    Ok(())
}

#[test]
fn row_group_predicate_nested_dict() -> Result<()> {
    // The column is a dictionary whose value type is itself a dictionary:
    // Dictionary<UInt8, Dictionary<UInt16, Utf8>>.
    let inner_dict =
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));
    let outer_dict = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(inner_dict));
    let schema = Schema::new(vec![Field::new("c1", outer_dict, false)]);
    let expected_expr =
        "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1";

    // Compare the column directly (no cast) with a plain Utf8 literal.
    let literal = lit(ScalarValue::Utf8(Some("test".to_string())));
    let expr = col("c1").eq(literal);

    let predicate_expr =
        test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
    assert_eq!(predicate_expr.to_string(), expected_expr);

    Ok(())
}

#[test]
fn row_group_predicate_dict_date_dict_date() -> Result<()> {
    // Both the column type and the cast target are dictionaries wrapping date types:
    // Dictionary<UInt8, Date32> is cast to Dictionary<UInt16, Date64>.
    let column_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32));
    let target_type =
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64));
    let schema = Schema::new(vec![Field::new("c1", column_type, false)]);
    let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))";

    // Cast to the other date type and compare with a Date64 literal.
    let literal = lit(ScalarValue::Date64(Some(123)));
    let expr = cast(col("c1"), target_type).eq(literal);

    let predicate_expr =
        test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
    assert_eq!(predicate_expr.to_string(), expected_expr);

    Ok(())
}

#[test]
fn row_group_predicate_date_string() -> Result<()> {
    // A Utf8 column cast to Date32 and compared with a Date32 literal; the
    // expected rewritten predicate is just the literal `true`.
    let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]);
    let expected_expr = "true";

    let cast_column = || cast(col("c1"), DataType::Date32);
    let literal = || lit(ScalarValue::Date32(Some(123)));

    // Column on the left, then column on the right.
    for expr in [cast_column().eq(literal()), literal().eq(cast_column())] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_string_date() -> Result<()> {
    // A Date32 column cast to Utf8 and compared with a Utf8 literal; the
    // expected rewritten predicate is just the literal `true`.
    let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]);
    let expected_expr = "true";

    let column_on_left = cast(col("c1"), DataType::Utf8)
        .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))));
    let column_on_right = lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))
        .eq(cast(col("c1"), DataType::Utf8));

    for expr in [column_on_left, column_on_right] {
        let predicate_expr =
            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
        assert_eq!(predicate_expr.to_string(), expected_expr);
    }

    Ok(())
}

#[test]
fn row_group_predicate_cast_list() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
Expand Down