From 6518a80f99b9182ea7c669af103a2c60ec6693b5 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 25 Jun 2024 09:43:35 +0400 Subject: [PATCH] feat: Improve SQL support for array indexing, increase test coverage (#16972) --- crates/polars-sql/src/functions.rs | 14 ++- crates/polars-sql/src/sql_expr.rs | 46 ++++++++- .../source/reference/sql/functions/array.rst | 6 +- py-polars/tests/unit/sql/test_array.py | 95 +++++++++++++++++-- 4 files changed, 142 insertions(+), 19 deletions(-) diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 155953448d69..f8b2f281a025 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -15,7 +15,7 @@ use sqlparser::ast::{ OrderByExpr, Value as SQLValue, WindowSpec, WindowType, }; -use crate::sql_expr::{parse_extract_date_part, parse_sql_expr}; +use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr}; use crate::SQLContext; pub(crate) struct SQLFunctionVisitor<'a> { @@ -1016,7 +1016,7 @@ impl SQLFunctionVisitor<'_> { }, OctetLength => self.visit_unary(|e| e.str().len_bytes()), StrPos => { - // note: 1-indexed, not 0-indexed, and returns zero if match not found + // note: SQL is 1-indexed; returns zero if no match found self.visit_binary(|expr, substring| { (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32)) }) @@ -1092,7 +1092,7 @@ impl SQLFunctionVisitor<'_> { Substring => { let args = extract_args(function)?; match args.len() { - // note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments + // note: SQL is 1-indexed, hence the need for adjustments 2 => self.try_visit_binary(|e, start| { Ok(match start { Expr::Literal(Null) => lit(Null), @@ -1148,7 +1148,13 @@ impl SQLFunctionVisitor<'_> { // ---- ArrayAgg => self.visit_arr_agg(), ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), - ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)), + ArrayGet => { + // 
note: SQL is 1-indexed, not 0-indexed + self.visit_binary(|e, idx: Expr| { + let idx = adjust_one_indexed_param(idx, true); + e.list().get(idx, true) + }) + }, ArrayLength => self.visit_unary(|e| e.list().len()), ArrayMax => self.visit_unary(|e| e.list().max()), ArrayMean => self.visit_unary(|e| e.list().mean()), diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index bcf95d75f8d4..8fffef9c1e8a 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -17,7 +17,7 @@ use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - Interval, JoinConstraint, ObjectName, Query as Subquery, SelectItem, TimezoneInfo, + Interval, JoinConstraint, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; @@ -295,7 +295,7 @@ impl SQLExprVisitor<'_> { } => self.visit_like(*negated, expr, pattern, escape_char, true), SQLExpr::Nested(expr) => self.visit_expr(expr), SQLExpr::Position { expr, r#in } => Ok( - // note: SQL is 1-indexed, not 0-indexed + // note: SQL is 1-indexed (self .visit_expr(r#in)? 
.str() @@ -316,6 +316,7 @@ impl SQLExprVisitor<'_> { .contains(self.visit_expr(pattern)?, true); Ok(if *negated { matches.not() } else { matches }) }, + SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript), SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"), SQLExpr::Trim { expr, @@ -443,6 +444,19 @@ impl SQLExprVisitor<'_> { } } + fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult { + let expr = self.visit_expr(expr)?; + Ok(match subscript { + Subscript::Index { index } => { + let idx = adjust_one_indexed_param(self.visit_expr(index)?, true); + expr.list().get(idx, true) + }, + Subscript::Slice { .. } => { + polars_bail!(SQLSyntax: "array slice syntax is not currently supported") + }, + }) + } + /// Handle implicit temporal string comparisons. /// /// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'" @@ -1190,6 +1204,34 @@ pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> Pola }) } +/// Allow an expression that represents a 1-indexed parameter to +/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars) +pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr { + match idx { + Expr::Literal(Null) => lit(Null), + Expr::Literal(LiteralValue::Int(0)) => { + if null_if_zero { + lit(Null) + } else { + idx + } + }, + Expr::Literal(LiteralValue::Int(n)) if n < 0 => idx, + Expr::Literal(LiteralValue::Int(n)) => lit(n - 1), + // TODO: when 'saturating_sub' is available, should be able + // to streamline the when/then/otherwise block below - + _ => when(idx.clone().gt(lit(0))) + .then(idx.clone() - lit(1)) + .otherwise(if null_if_zero { + when(idx.clone().eq(lit(0))) + .then(lit(Null)) + .otherwise(idx.clone()) + } else { + idx.clone() + }), + } +} + fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { let n_bits = b.len(); if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { diff --git 
a/py-polars/docs/source/reference/sql/functions/array.rst b/py-polars/docs/source/reference/sql/functions/array.rst index 9bf76d0dc374..6774aa0444ed 100644 --- a/py-polars/docs/source/reference/sql/functions/array.rst +++ b/py-polars/docs/source/reference/sql/functions/array.rst @@ -105,7 +105,7 @@ Returns the value at the given index in the array. SELECT foo, bar, ARRAY_GET(foo, 1) AS foo_at_1, - ARRAY_GET(bar, 2) AS bar_at_2 + ARRAY_GET(bar, 3) AS bar_at_2 FROM self """) # shape: (2, 4) @@ -114,8 +114,8 @@ Returns the value at the given index in the array. # │ --- ┆ --- ┆ --- ┆ --- │ # │ list[i64] ┆ list[i64] ┆ i64 ┆ i64 │ # ╞═══════════╪════════════╪══════════╪══════════╡ - # │ [1, 2] ┆ [6, 7] ┆ 2 ┆ null │ - # │ [4, 3, 2] ┆ [8, 9, 10] ┆ 3 ┆ 10 │ + # │ [1, 2] ┆ [6, 7] ┆ 1 ┆ null │ + # │ [4, 3, 2] ┆ [8, 9, 10] ┆ 4 ┆ 10 │ # └───────────┴────────────┴──────────┴──────────┘ .. _array_length: diff --git a/py-polars/tests/unit/sql/test_array.py b/py-polars/tests/unit/sql/test_array.py index 62f9781f9296..1c513707db7e 100644 --- a/py-polars/tests/unit/sql/test_array.py +++ b/py-polars/tests/unit/sql/test_array.py @@ -25,7 +25,8 @@ def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) -> order_by = "" if not sort_order else f" ORDER BY col0 {sort_order}" limit_clause = "" if not limit else f" LIMIT {limit}" - res = pl.sql(f""" + res = pl.sql( + f""" WITH data (col0, col1, col2) as ( VALUES (1,'a','x'), @@ -38,7 +39,8 @@ def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) -> FROM data GROUP BY col1 ORDER BY col1 - """).collect() + """ + ).collect() assert res.rows() == expected @@ -48,9 +50,17 @@ def test_array_literals() -> None: res = ctx.execute( """ SELECT - a1, a2, ARRAY_AGG(a1) AS a3, ARRAY_AGG(a2) AS a4 + a1, a2, + -- test some array ops + ARRAY_AGG(a1) AS a3, + ARRAY_AGG(a2) AS a4, + ARRAY_CONTAINS(a1,20) AS i20, + ARRAY_CONTAINS(a2,'zz') AS izz, + ARRAY_REVERSE(a1) AS ar1, + ARRAY_REVERSE(a2) AS ar2 FROM ( SELECT + 
-- declare array literals [10,20,30] AS a1, ['a','b','c'] AS a2, FROM df @@ -65,19 +75,81 @@ def test_array_literals() -> None: "a2": [["a", "b", "c"]], "a3": [[[10, 20, 30]]], "a4": [[["a", "b", "c"]]], + "i20": [True], + "izz": [False], + "ar1": [[30, 20, 10]], + "ar2": [["c", "b", "a"]], } ), ) +@pytest.mark.parametrize( + ("array_index", "expected"), + [ + (-4, None), + (-3, 99), + (-2, 66), + (-1, 33), + (0, None), + (1, 99), + (2, 66), + (3, 33), + (4, None), + ], +) +def test_array_indexing(array_index: int, expected: int | None) -> None: + res = pl.sql( + f""" + SELECT + arr[{array_index}] AS idx1, + ARRAY_GET(arr,{array_index}) AS idx2, + FROM (SELECT [99,66,33] AS arr) tbl + """ + ).collect() + + assert_frame_equal( + res, + pl.DataFrame( + {"idx1": [expected], "idx2": [expected]}, + ), + check_dtypes=False, + ) + + +def test_array_indexing_by_expr() -> None: + df = pl.DataFrame( + { + "idx": [-2, -1, 0, None, 1, 2, 3], + "arr": [[0, 1, 2, 3], [4, 5], [6], [7, 8, 9], [8, 7], [6, 5, 4], [3, 2, 1]], + } + ) + res = df.sql( + """ + SELECT + arr[idx] AS idx1, + ARRAY_GET(arr, idx) AS idx2 + FROM self + """ + ) + expected = [2, 5, None, None, 8, 5, 1] + assert_frame_equal(res, pl.DataFrame({"idx1": expected, "idx2": expected})) + + def test_array_to_string() -> None: - data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]} + data = { + "s_values": [["aa", "bb"], [None, "cc"], ["dd", None]], + "n_values": [[999, 777], [None, 555], [333, None]], + } res = pl.DataFrame(data).sql( """ SELECT - ARRAY_TO_STRING(values, '') AS v1, - ARRAY_TO_STRING(values, ':') AS v2, - ARRAY_TO_STRING(values, ':', 'NA') AS v3 + ARRAY_TO_STRING(s_values, '') AS vs1, + ARRAY_TO_STRING(s_values, ':') AS vs2, + ARRAY_TO_STRING(s_values, ':', 'NA') AS vs3, + ARRAY_TO_STRING(n_values, '') AS vn1, + ARRAY_TO_STRING(n_values, ':') AS vn2, + ARRAY_TO_STRING(n_values, ':', 'NA') AS vn3 FROM self """ ) @@ -85,9 +157,12 @@ def test_array_to_string() -> None: res, pl.DataFrame( { - "v1": 
["aabb", "cc", "dd"], - "v2": ["aa:bb", "cc", "dd"], - "v3": ["aa:bb", "NA:cc", "dd:NA"], + "vs1": ["aabb", "cc", "dd"], + "vs2": ["aa:bb", "cc", "dd"], + "vs3": ["aa:bb", "NA:cc", "dd:NA"], + "vn1": ["999777", "555", "333"], + "vn2": ["999:777", "555", "333"], + "vn3": ["999:777", "NA:555", "333:NA"], } ), )