Skip to content

Commit

Permalink
fix(rust,python,cli): handle unary minus applied to numbers used in S…
Browse files Browse the repository at this point in the history
…QL `IN` clauses
  • Loading branch information
alexander-beedie committed Oct 6, 2023
1 parent 6a44fa1 commit c3d040a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
35 changes: 29 additions & 6 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,32 @@ impl SqlExprVisitor<'_> {
})
}

// similar to visit_literal, but returns an AnyValue instead of Expr
fn visit_anyvalue(&self, value: &SqlValue) -> PolarsResult<AnyValue> {
/// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr
fn visit_anyvalue(
&self,
value: &SqlValue,
op: Option<&UnaryOperator>,
) -> PolarsResult<AnyValue> {
Ok(match value {
SqlValue::Boolean(b) => AnyValue::Boolean(*b),
SqlValue::Null => AnyValue::Null,
SqlValue::Number(s, _) => {
let negate = match op {
Some(UnaryOperator::Minus) => true,
Some(UnaryOperator::Plus) => false,
_ => {
polars_bail!(ComputeError: "Unary op {:?} not supported for numeric SQL value", op)
},
};
// Check for existence of decimal separator dot
if s.contains('.') {
s.parse::<f64>().map(AnyValue::Float64).map_err(|_| ())
s.parse::<f64>()
.map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
.map_err(|_| ())
} else {
s.parse::<i64>().map(AnyValue::Int64).map_err(|_| ())
s.parse::<i64>()
.map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
.map_err(|_| ())
}
.map_err(|_| polars_err!(ComputeError: "cannot parse literal: {s:?}"))?
},
Expand Down Expand Up @@ -483,9 +498,17 @@ impl SqlExprVisitor<'_> {
.iter()
.map(|e| {
if let SqlExpr::Value(v) = e {
let av = self.visit_anyvalue(v)?;
let av = self.visit_anyvalue(v, None)?;
Ok(av)
} else {
} else if let SqlExpr::UnaryOp {op, expr} = e {
match expr.as_ref() {
SqlExpr::Value(v) => {
let av = self.visit_anyvalue(v, Some(op))?;
Ok(av)
},
_ => Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
}
}else{
Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
}
})
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,3 +1123,23 @@ def test_sql_expr() -> None:
pl.InvalidOperationError, match=r"Unable to parse 'xyz\.\*' as Expr"
):
pl.sql_expr("xyz.*")


@pytest.mark.parametrize("match_float", [False, True])
def test_sql_unary_ops_8890(match_float: bool) -> None:
with pl.SQLContext(
df=pl.DataFrame({"a": [-2, -1, 1, 2], "b": ["w", "x", "y", "z"]}),
) as ctx:
in_values = "(-3.0, -1.0, +2.0, +4.0)" if match_float else "(-3, -1, +2, +4)"
res = ctx.execute(
f"""
SELECT *, -(3) as c, (+4) as d
FROM df WHERE a IN {in_values}
"""
)
assert res.collect().to_dict(False) == {
"a": [-1, 2],
"b": ["x", "z"],
"c": [-3, -3],
"d": [4, 4],
}

0 comments on commit c3d040a

Please sign in to comment.