From f103fa8ccad65dd950fe9aa8f1a5ad57ea9357bd Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sun, 27 Oct 2024 12:51:21 +0400 Subject: [PATCH] fix: Don't panic in SQL temporal string check; raise suitable `ColumnNotFound` error (#19473) --- crates/polars-sql/src/sql_expr.rs | 21 ++++++++-------- .../tests/unit/sql/test_miscellaneous.py | 25 ++++++++++++++++++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index a6fe495d1ba5..5eb2bdd843b4 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -374,10 +374,9 @@ impl SQLExprVisitor<'_> { }, // identify "CAST(expr AS type) string" and/or "expr::type string" expressions (Expr::Cast { expr, dtype, .. }, Expr::Literal(LiteralValue::String(s))) => { - if let Expr::Column(name) = &**expr { - (Some(name.clone()), Some(s), Some(dtype)) - } else { - (None, Some(s), Some(dtype)) + match &**expr { + Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)), + _ => (None, Some(s), Some(dtype)), } }, _ => (None, None, None), @@ -385,23 +384,25 @@ impl SQLExprVisitor<'_> { if expr_dtype.is_none() && self.active_schema.is_none() { right.clone() } else { - let left_dtype = expr_dtype - .unwrap_or_else(|| self.active_schema.as_ref().unwrap().get(&name).unwrap()); - + let left_dtype = expr_dtype.or_else(|| { + self.active_schema + .as_ref() + .and_then(|schema| schema.get(&name)) + }); match left_dtype { - DataType::Time if is_iso_time(s) => { + Some(DataType::Time) if is_iso_time(s) => { right.clone().str().to_time(StrptimeOptions { strict: true, ..Default::default() }) }, - DataType::Date if is_iso_date(s) => { + Some(DataType::Date) if is_iso_date(s) => { right.clone().str().to_date(StrptimeOptions { strict: true, ..Default::default() }) }, - DataType::Datetime(tu, tz) if is_iso_datetime(s) || is_iso_date(s) => { + Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => { if s.len() == 10 { // handle upcast from ISO date string (10 chars) to datetime lit(format!("{}T00:00:00", s)) diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 95ba8461bebe..f7d0615e13c6 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -7,7 +7,7 @@ import pytest import polars as pl -from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal if TYPE_CHECKING: @@ -362,3 +362,26 @@ def test_global_variable_inference_17398() -> None: eager=True, ) assert_frame_equal(res, users) + + +@pytest.mark.parametrize( + "query", + [ + "SELECT invalid_column FROM self", + "SELECT key, invalid_column FROM self", + "SELECT invalid_column * 2 FROM self", + "SELECT * FROM self ORDER BY invalid_column", + "SELECT * FROM self WHERE invalid_column = 200", + "SELECT * FROM self WHERE invalid_column = '200'", + "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column", + ], +) +def test_invalid_cols(query: str) -> None: + df = pl.DataFrame( + { + "key": ["xx", "xx", "yy"], + "n": ["100", "200", "300"], + } + ) + with pytest.raises(ColumnNotFoundError, match="invalid_column"): + df.sql(query)