Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support PostgreSQL ^@ ("starts with"), and ~~,~~*,!~~,!~~* ("like", "ilike") string-matching operators #17251

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 64 additions & 32 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,55 +527,87 @@ impl SQLExprVisitor<'_> {
op: &BinaryOperator,
right: &SQLExpr,
) -> PolarsResult<Expr> {
let left = self.visit_expr(left)?;
let mut right = self.visit_expr(right)?;
right = self.convert_temporal_strings(&left, &right);
let lhs = self.visit_expr(left)?;
let mut rhs = self.visit_expr(right)?;
rhs = self.convert_temporal_strings(&lhs, &rhs);

Ok(match op {
SQLBinaryOperator::And => left.and(right),
SQLBinaryOperator::Divide => left / right,
SQLBinaryOperator::DuckIntegerDivide => left.floor_div(right).cast(DataType::Int64),
SQLBinaryOperator::Eq => left.eq(right),
SQLBinaryOperator::Gt => left.gt(right),
SQLBinaryOperator::GtEq => left.gt_eq(right),
SQLBinaryOperator::Lt => left.lt(right),
SQLBinaryOperator::LtEq => left.lt_eq(right),
SQLBinaryOperator::Minus => left - right,
SQLBinaryOperator::Modulo => left % right,
SQLBinaryOperator::Multiply => left * right,
SQLBinaryOperator::NotEq => left.eq(right).not(),
SQLBinaryOperator::Or => left.or(right),
SQLBinaryOperator::Plus => left + right,
SQLBinaryOperator::Spaceship => left.eq_missing(right),
SQLBinaryOperator::And => lhs.and(rhs),
SQLBinaryOperator::Divide => lhs / rhs,
SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),
SQLBinaryOperator::Eq => lhs.eq(rhs),
SQLBinaryOperator::Gt => lhs.gt(rhs),
SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),
SQLBinaryOperator::Lt => lhs.lt(rhs),
SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),
SQLBinaryOperator::Minus => lhs - rhs,
SQLBinaryOperator::Modulo => lhs % rhs,
SQLBinaryOperator::Multiply => lhs * rhs,
SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),
SQLBinaryOperator::Or => lhs.or(rhs),
SQLBinaryOperator::Plus => lhs + rhs,
SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),
SQLBinaryOperator::StringConcat => {
left.cast(DataType::String) + right.cast(DataType::String)
lhs.cast(DataType::String) + rhs.cast(DataType::String)
},
SQLBinaryOperator::Xor => left.xor(right),
SQLBinaryOperator::Xor => lhs.xor(rhs),
SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),
// ----
// Regular expression operators
// ----
SQLBinaryOperator::PGRegexMatch => match right {
Expr::Literal(LiteralValue::String(_)) => left.str().contains(right, true),
_ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", right),
// "a ~ b"
SQLBinaryOperator::PGRegexMatch => match rhs {
Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true),
_ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
},
SQLBinaryOperator::PGRegexNotMatch => match right {
Expr::Literal(LiteralValue::String(_)) => left.str().contains(right, true).not(),
_ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", right),
// "a !~ b"
SQLBinaryOperator::PGRegexNotMatch => match rhs {
Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(),
_ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
},
SQLBinaryOperator::PGRegexIMatch => match right {
// "a ~* b"
SQLBinaryOperator::PGRegexIMatch => match rhs {
Expr::Literal(LiteralValue::String(pat)) => {
left.str().contains(lit(format!("(?i){}", pat)), true)
lhs.str().contains(lit(format!("(?i){}", pat)), true)
},
_ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", right),
_ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
},
SQLBinaryOperator::PGRegexNotIMatch => match right {
// "a !~* b"
SQLBinaryOperator::PGRegexNotIMatch => match rhs {
Expr::Literal(LiteralValue::String(pat)) => {
left.str().contains(lit(format!("(?i){}", pat)), true).not()
lhs.str().contains(lit(format!("(?i){}", pat)), true).not()
},
_ => {
polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", right)
polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
},
},
// ----
// LIKE/ILIKE operators
// ----
SQLBinaryOperator::PGLikeMatch
| SQLBinaryOperator::PGNotLikeMatch
| SQLBinaryOperator::PGILikeMatch
| SQLBinaryOperator::PGNotILikeMatch => {
let expr = if matches!(
op,
SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
) {
SQLExpr::Like {
negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
expr: Box::new(left.clone()),
pattern: Box::new(right.clone()),
escape_char: None,
}
} else {
SQLExpr::ILike {
negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
expr: Box::new(left.clone()),
pattern: Box::new(right.clone()),
escape_char: None,
}
};
self.visit_expr(&expr)?
},
other => {
polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
},
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/sql/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,25 @@ def test_is_between(foods_ipc_path: Path) -> None:
assert not any((22 <= cal <= 30) for cal in out["calories"])


def test_starts_with() -> None:
lf = pl.LazyFrame(
{
"x": ["aaa", "bbb", "a"],
"y": ["abc", "b", "aa"],
},
)
assert lf.sql("SELECT x ^@ 'a' AS x_starts_with_a FROM self").collect().rows() == [
(True,),
(False,),
(True,),
]
assert lf.sql("SELECT x ^@ y AS x_starts_with_y FROM self").collect().rows() == [
(False,),
(True,),
(False,),
]


@pytest.mark.parametrize("match_float", [False, True])
def test_unary_ops_8890(match_float: bool) -> None:
with pl.SQLContext(
Expand Down
22 changes: 11 additions & 11 deletions py-polars/tests/unit/sql/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,15 @@ def test_string_lengths() -> None:
("_0%_", "LIKE", [2, 4]),
("%0", "LIKE", [2]),
("0%", "LIKE", [2]),
("__0%", "LIKE", [2, 3]),
("%*%", "ILIKE", [3]),
("____", "LIKE", [4]),
("a%C", "LIKE", []),
("a%C", "ILIKE", [0, 1, 3]),
("%C?", "ILIKE", [4]),
("a0c?", "LIKE", [4]),
("000", "LIKE", [2]),
("00", "LIKE", []),
("__0%", "~~", [2, 3]),
("%*%", "~~*", [3]),
("____", "~~", [4]),
("a%C", "~~", []),
("a%C", "~~*", [0, 1, 3]),
("%C?", "~~*", [4]),
("a0c?", "~~", [4]),
("000", "~~", [2]),
("00", "~~", []),
],
)
def test_string_like(pattern: str, like: str, expected: list[int]) -> None:
Expand All @@ -235,9 +235,9 @@ def test_string_like(pattern: str, like: str, expected: list[int]) -> None:
}
)
with pl.SQLContext(df=df) as ctx:
for not_ in ("", "NOT "):
for not_ in ("", ("NOT " if like.endswith("LIKE") else "!")):
out = ctx.execute(
f"""SELECT idx FROM df WHERE txt {not_}{like} '{pattern}'"""
f"SELECT idx FROM df WHERE txt {not_}{like} '{pattern}'"
).collect()

res = out["idx"].to_list()
Expand Down