From 4d7bdb35d606f08625e83170a5bfad84eddd62f6 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 6 Jun 2024 13:27:59 +0400 Subject: [PATCH] feat: Extend recognised `EXTRACT` and `DATE_PART` SQL part abbreviations --- crates/polars-sql/src/functions.rs | 15 +++- crates/polars-sql/src/sql_expr.rs | 74 ++++++++++--------- crates/polars-time/src/windows/duration.rs | 10 +-- .../reference/sql/functions/temporal.rst | 66 +++++++++-------- py-polars/tests/unit/sql/test_temporal.py | 65 +++++++++------- 5 files changed, 128 insertions(+), 102 deletions(-) diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index be78be02baff..37934dee065c 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -10,11 +10,11 @@ use polars_plan::prelude::col; use polars_plan::prelude::LiteralValue::Null; use polars_plan::prelude::{lit, StrptimeOptions}; use sqlparser::ast::{ - Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Value as SQLValue, - WindowSpec, WindowType, + DateTimeField, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Ident, + Value as SQLValue, WindowSpec, WindowType, }; -use crate::sql_expr::{parse_date_part, parse_sql_expr}; +use crate::sql_expr::{parse_extract_date_part, parse_sql_expr}; use crate::SQLContext; pub(crate) struct SQLFunctionVisitor<'a> { @@ -889,7 +889,14 @@ impl SQLFunctionVisitor<'_> { }, DatePart => self.try_visit_binary(|part, e| { match part { - Expr::Literal(LiteralValue::String(p)) => parse_date_part(e, &p), + Expr::Literal(LiteralValue::String(p)) => { + // note: 'DATE_PART' and 'EXTRACT' are minor syntactic + // variations on otherwise identical functionality + parse_extract_date_part(e, &DateTimeField::Custom(Ident { + value: p, + quote_style: None, + })) + }, _ => { polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART: {}", function.args[1]); } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 8e2de9ebacb7..0f59a87f56d3 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -240,7 +240,9 @@ impl SQLExprVisitor<'_> { } => self.visit_cast(expr, data_type, format, true), SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()), SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents), - SQLExpr::Extract { field, expr } => parse_extract(self.visit_expr(expr)?, field), + SQLExpr::Extract { field, expr } => { + parse_extract_date_part(self.visit_expr(expr)?, field) + }, SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()), SQLExpr::Function(function) => self.visit_function(function), SQLExpr::Identifier(ident) => self.visit_identifier(ident), @@ -1171,7 +1173,41 @@ pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsRes } } -fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult { +pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult { + let field = match field { + // handle 'DATE_PART' and all valid abbreviations/alternates + DateTimeField::Custom(Ident { value, .. }) => { + let value = value.to_ascii_lowercase(); + match value.as_str() { + "millennium" | "millennia" => &DateTimeField::Millennium, + "century" | "centuries" => &DateTimeField::Century, + "decade" | "decades" => &DateTimeField::Decade, + "isoyear" => &DateTimeField::Isoyear, + "year" | "years" | "y" => &DateTimeField::Year, + "quarter" | "quarters" => &DateTimeField::Quarter, + "month" | "months" | "mon" | "mons" => &DateTimeField::Month, + "dayofyear" | "doy" => &DateTimeField::DayOfYear, + "dayofweek" | "dow" => &DateTimeField::DayOfWeek, + "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek, + "isodow" => &DateTimeField::Isodow, + "day" | "days" | "d" => &DateTimeField::Day, + "hour" | "hours" | "h" => &DateTimeField::Hour, + "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute, + "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second, + "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond, + "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond, + "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond, + #[cfg(feature = "timezones")] + "timezone" => &DateTimeField::Timezone, + "time" => &DateTimeField::Time, + "epoch" => &DateTimeField::Epoch, + _ => { + polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value) + }, + } + }, + _ => field, + }; Ok(match field { DateTimeField::Millennium => expr.dt().millennium(), DateTimeField::Century => expr.dt().century(), @@ -1226,40 +1262,6 @@ fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult { }) } -pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult { - let part = part.to_ascii_lowercase(); - parse_extract( - expr, - match part.as_str() { - "millennium" => &DateTimeField::Millennium, - "century" => &DateTimeField::Century, - "decade" => &DateTimeField::Decade, - "isoyear" => &DateTimeField::Isoyear, - "year" => &DateTimeField::Year, - "quarter" => &DateTimeField::Quarter, - "month" => &DateTimeField::Month, - "dayofyear" | "doy" => &DateTimeField::DayOfYear, - "dayofweek" | "dow" => &DateTimeField::DayOfWeek, - "isoweek" | "week" => &DateTimeField::IsoWeek, - "isodow" => &DateTimeField::Isodow, - "day" => &DateTimeField::Day, - "hour" => &DateTimeField::Hour, - "minute" => &DateTimeField::Minute, - "second" => &DateTimeField::Second, - "millisecond" | "milliseconds" => &DateTimeField::Millisecond, - "microsecond" | "microseconds" => &DateTimeField::Microsecond, - "nanosecond" | "nanoseconds" => &DateTimeField::Nanosecond, - #[cfg(feature = "timezones")] - "timezone" => &DateTimeField::Timezone, - "time" => &DateTimeField::Time, - "epoch" => &DateTimeField::Epoch, - _ => { - polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", part) - }, - }, - ) -} - fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { let n_bits = b.len(); if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index b96d4b16e0d3..a459a935107b 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -257,13 +257,9 @@ impl Duration { }, _ if as_interval => match &*unit { // interval-only (verbose/sql) matches - "nanosec" | "nanosecs" | "nanosecond" | "nanoseconds" => nsecs += n, - "microsec" | "microsecs" | "microsecond" | "microseconds" => { - nsecs += n * NS_MICROSECOND - }, - "millisec" | "millisecs" | "millisecond" | "milliseconds" => { - nsecs += n * NS_MILLISECOND - }, + "nanosecond" | "nanoseconds" => nsecs += n, + "microsecond" | "microseconds" => nsecs += n * NS_MICROSECOND, + "millisecond" | "milliseconds" => nsecs += n * NS_MILLISECOND, "sec" | "secs" | "second" | "seconds" => nsecs += n * NS_SECOND, "min" | "mins" | "minute" | "minutes" => nsecs += n * NS_MINUTE, "hour" | "hours" => nsecs += n * NS_HOUR, diff --git a/py-polars/docs/source/reference/sql/functions/temporal.rst b/py-polars/docs/source/reference/sql/functions/temporal.rst index d54ee6e5693f..e698fc439907 100644 --- a/py-polars/docs/source/reference/sql/functions/temporal.rst +++ b/py-polars/docs/source/reference/sql/functions/temporal.rst @@ -47,24 +47,27 @@ DATE_PART Extracts a part of a date (or datetime) such as 'year', 'month', etc. **Supported parts/fields:** - - "day" - - "dayofweek" | "dow" + - "millennium" | "millennia" + - "century" | "centuries" + - "decade" | "decades" + - "isoyear" + - "year" | "years" | "y" + - "quarter" | "quarters" + - "month" | "months" | "mon" | "mons" - "dayofyear" | "doy" - - "decade" - - "epoch" - - "hour" - - "isodow" + - "dayofweek" | "dow" - "isoweek" | "week" - - "isoyear" - - "microsecond(s)" - - "millisecond(s)" - - "nanosecond(s)" - - "minute" - - "month" - - "quarter" - - "second" + - "isodow" + - "day" | "days" | "d" + - "hour" | "hours" | "h" + - "minute" | "minutes" | "mins" | "min" | "m" + - "second" | "seconds" | "sec" | "secs" | "s" + - "millisecond" | "milliseconds" | "ms" + - "microsecond" | "microseconds" | "us" + - "nanosecond" | "nanoseconds" | "ns" + - "timezone" - "time" - - "year" + - "epoch" **Example:** @@ -106,24 +109,27 @@ EXTRACT Extracts a part of a date (or datetime) such as 'year', 'month', etc. **Supported parts/fields:** - - "day" - - "dayofweek" | "dow" + - "millennium" | "millennia" + - "century" | "centuries" + - "decade" | "decades" + - "isoyear" + - "year" | "years" | "y" + - "quarter" | "quarters" + - "month" | "months" | "mon" | "mons" - "dayofyear" | "doy" - - "decade" - - "epoch" - - "hour" - - "isodow" + - "dayofweek" | "dow" - "isoweek" | "week" - - "isoyear" - - "microsecond(s)" - - "millisecond(s)" - - "nanosecond(s)" - - "minute" - - "month" - - "quarter" - - "second" + - "isodow" + - "day" | "days" | "d" + - "hour" | "hours" | "h" + - "minute" | "minutes" | "mins" | "min" | "m" + - "second" | "seconds" | "sec" | "secs" | "s" + - "millisecond" | "milliseconds" | "ms" + - "microsecond" | "microseconds" | "us" + - "nanosecond" | "nanoseconds" | "ns" + - "timezone" - "time" - - "year" + - "epoch" .. code-block:: python diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index cd4919a13485..3233369a5ccd 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -56,37 +56,49 @@ def test_datetime_to_time(time_unit: Literal["ns", "us", "ms"]) -> None: @pytest.mark.parametrize( - ("part", "dtype", "expected"), + ("parts", "dtype", "expected"), [ - ("decade", pl.Int32, [202, 202, 200]), - ("isoyear", pl.Int32, [2024, 2020, 2005]), - ("year", pl.Int32, [2024, 2020, 2006]), - ("quarter", pl.Int8, [1, 4, 1]), - ("month", pl.Int8, [1, 12, 1]), - ("week", pl.Int8, [1, 53, 52]), - ("doy", pl.Int16, [7, 365, 1]), - ("isodow", pl.Int8, [7, 3, 7]), - ("dow", pl.Int8, [0, 3, 0]), - ("day", pl.Int8, [7, 30, 1]), - ("hour", pl.Int8, [1, 10, 23]), - ("minute", pl.Int8, [2, 30, 59]), - ("second", pl.Int8, [3, 45, 59]), - ("millisecond", pl.Float64, [3123.456, 45987.654, 59555.555]), - ("microsecond", pl.Float64, [3123456.0, 45987654.0, 59555555.0]), - ("nanosecond", pl.Float64, [3123456000.0, 45987654000.0, 59555555000.0]), + (["decade", "decades"], pl.Int32, [202, 202, 200]), + (["isoyear"], pl.Int32, [2024, 2020, 2005]), + (["year", "y"], pl.Int32, [2024, 2020, 2006]), + (["quarter"], pl.Int8, [1, 4, 1]), + (["month", "months", "mon", "mons"], pl.Int8, [1, 12, 1]), + (["week", "weeks"], pl.Int8, [1, 53, 52]), + (["doy"], pl.Int16, [7, 365, 1]), + (["isodow"], pl.Int8, [7, 3, 7]), + (["dow"], pl.Int8, [0, 3, 0]), + (["day", "days", "d"], pl.Int8, [7, 30, 1]), + (["hour", "hours", "h"], pl.Int8, [1, 10, 23]), + (["minute", "min", "mins", "m"], pl.Int8, [2, 30, 59]), + (["second", "seconds", "secs", "sec"], pl.Int8, [3, 45, 59]), ( - "time", + ["millisecond", "milliseconds", "ms"], + pl.Float64, + [3123.456, 45987.654, 59555.555], + ), + ( + ["microsecond", "microseconds", "us"], + pl.Float64, + [3123456.0, 45987654.0, 59555555.0], + ), + ( + ["nanosecond", "nanoseconds", "ns"], + pl.Float64, + [3123456000.0, 45987654000.0, 59555555000.0], + ), + ( + ["time"], pl.Time, [time(1, 2, 3, 123456), time(10, 30, 45, 987654), time(23, 59, 59, 555555)], ), ( - "epoch", + ["epoch"], pl.Float64, [1704589323.123456, 1609324245.987654, 1136159999.555555], ), ], ) -def test_extract(part: str, dtype: pl.DataType, expected: list[Any]) -> None: +def test_extract(parts: list[str], dtype: pl.DataType, expected: list[Any]) -> None: df = pl.DataFrame( { "dt": [ @@ -100,11 +112,14 @@ def test_extract(part: str, dtype: pl.DataType, expected: list[Any]) -> None: } ) with pl.SQLContext(frame_data=df, eager=True) as ctx: - for func in (f"EXTRACT({part} FROM dt)", f"DATE_PART('{part}',dt)"): - res = ctx.execute(f"SELECT {func} AS {part} FROM frame_data").to_series() - - assert res.dtype == dtype - assert res.to_list() == expected + for part in parts: + for fn in ( + f"EXTRACT({part} FROM dt)", + f"DATE_PART('{part}',dt)", + ): + res = ctx.execute(f"SELECT {fn} AS {part} FROM frame_data").to_series() + assert res.dtype == dtype + assert res.to_list() == expected def test_extract_errors() -> None: