diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 7a31f8bbb988..7d080d6d04d9 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -209,7 +209,7 @@ impl SQLExprVisitor<'_> { expr, data_type, format, - } => self.visit_cast(expr, data_type, format), + } => self.visit_cast(expr, data_type, format, true), SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()), SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents), SQLExpr::Extract { field, expr } => parse_extract(self.visit_expr(expr)?, field), @@ -280,6 +280,11 @@ impl SQLExprVisitor<'_> { trim_what, trim_characters, } => self.visit_trim(expr, trim_where, trim_what, trim_characters), + SQLExpr::TryCast { + expr, + data_type, + format, + } => self.visit_cast(expr, data_type, format, false), SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr), SQLExpr::Value(value) => self.visit_literal(value), e @ SQLExpr::Case { .. } => self.visit_case_when_then(e), @@ -610,14 +615,15 @@ impl SQLExprVisitor<'_> { } } - /// Visit a SQL `CAST` expression. + /// Visit a SQL `CAST` or `TRY_CAST` expression. /// - /// e.g. `CAST(column AS INT)` or `column::INT` + /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`, fn visit_cast( &mut self, expr: &SQLExpr, data_type: &SQLDataType, format: &Option, + strict: bool, ) -> PolarsResult { if format.is_some() { return Err(polars_err!(ComputeError: "unsupported use of FORMAT in CAST expression")); @@ -629,7 +635,11 @@ impl SQLExprVisitor<'_> { return Ok(expr.str().json_decode(None, None)); } let polars_type = map_sql_polars_datatype(data_type)?; - Ok(expr.strict_cast(polars_type)) + Ok(if strict { + expr.strict_cast(polars_type) + } else { + expr.cast(polars_type) + }) } /// Visit a SQL literal. diff --git a/py-polars/tests/unit/sql/test_cast.py b/py-polars/tests/unit/sql/test_cast.py index 0f5cc61c1dc9..baccfbd74b06 100644 --- a/py-polars/tests/unit/sql/test_cast.py +++ b/py-polars/tests/unit/sql/test_cast.py @@ -161,9 +161,15 @@ def test_cast() -> None: def test_cast_errors(values: Any, cast_op: str, error: str) -> None: df = pl.DataFrame({"values": values}) + # invalid CAST should raise an error... with pytest.raises(ComputeError, match=error): df.sql(f"SELECT {cast_op} FROM df") + # ... or return `null` values if using TRY_CAST + target_type = cast_op.split("::")[1] + res = df.sql(f"SELECT TRY_CAST(values AS {target_type}) AS cast_values FROM df") + assert None in res.to_series() + def test_cast_json() -> None: df = pl.DataFrame({"txt": ['{"a":[1,2,3],"b":["x","y","z"],"c":5.0}']})