Skip to content

Commit

Permalink
feat: Add SQL support for TRY_CAST function
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 29, 2024
1 parent ef1c0c8 commit 2097e7e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
18 changes: 14 additions & 4 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl SQLExprVisitor<'_> {
expr,
data_type,
format,
} => self.visit_cast(expr, data_type, format),
} => self.visit_cast(expr, data_type, format, true),
SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
SQLExpr::Extract { field, expr } => parse_extract(self.visit_expr(expr)?, field),
Expand Down Expand Up @@ -280,6 +280,11 @@ impl SQLExprVisitor<'_> {
trim_what,
trim_characters,
} => self.visit_trim(expr, trim_where, trim_what, trim_characters),
SQLExpr::TryCast {
expr,
data_type,
format,
} => self.visit_cast(expr, data_type, format, false),
SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
SQLExpr::Value(value) => self.visit_literal(value),
e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
Expand Down Expand Up @@ -610,14 +615,15 @@ impl SQLExprVisitor<'_> {
}
}

/// Visit a SQL `CAST` expression.
/// Visit a SQL `CAST` or `TRY_CAST` expression.
///
/// e.g. `CAST(column AS INT)` or `column::INT`
/// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
fn visit_cast(
&mut self,
expr: &SQLExpr,
data_type: &SQLDataType,
format: &Option<CastFormat>,
strict: bool,
) -> PolarsResult<Expr> {
if format.is_some() {
return Err(polars_err!(ComputeError: "unsupported use of FORMAT in CAST expression"));
Expand All @@ -629,7 +635,11 @@ impl SQLExprVisitor<'_> {
return Ok(expr.str().json_decode(None, None));
}
let polars_type = map_sql_polars_datatype(data_type)?;
Ok(expr.strict_cast(polars_type))
Ok(if strict {
expr.strict_cast(polars_type)
} else {
expr.cast(polars_type)
})
}

/// Visit a SQL literal.
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/sql/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,15 @@ def test_cast() -> None:
def test_cast_errors(values: Any, cast_op: str, error: str) -> None:
df = pl.DataFrame({"values": values})

# invalid CAST should raise an error...
with pytest.raises(ComputeError, match=error):
df.sql(f"SELECT {cast_op} FROM df")

# ... or return `null` values if using TRY_CAST
target_type = cast_op.split("::")[1]
res = df.sql(f"SELECT TRY_CAST(values AS {target_type}) AS cast_values FROM df")
assert None in res.to_series()


def test_cast_json() -> None:
df = pl.DataFrame({"txt": ['{"a":[1,2,3],"b":["x","y","z"],"c":5.0}']})
Expand Down

0 comments on commit 2097e7e

Please sign in to comment.