feat: Improve SQL support for array indexing, increase test coverage (#…
alexander-beedie authored Jun 25, 2024
1 parent a69f6dd commit 6518a80
Showing 4 changed files with 142 additions and 19 deletions.
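In short: SQL array subscripts (arr[idx]) and ARRAY_GET are now consistently 1-indexed, negative indices count back from the end of the array, and index 0 (which has no element under 1-based indexing) returns null. The following minimal sketch is based on the tests added in this commit; the frame and column names are illustrative, and it assumes a Polars build that includes this change.

    import polars as pl

    # One row holding a three-element array column (names are illustrative).
    df = pl.DataFrame({"arr": [[99, 66, 33]]})

    out = df.sql(
        """
        SELECT
            arr[1] AS first_elem,             -- SQL is 1-indexed         -> 99
            ARRAY_GET(arr, -1) AS last_elem,  -- negative counts from end -> 33
            ARRAY_GET(arr, 0) AS zero_idx     -- index 0 has no element   -> null
        FROM self
        """
    )
    print(out)  # expect first_elem=99, last_elem=33, zero_idx=null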
14 changes: 10 additions & 4 deletions crates/polars-sql/src/functions.rs
@@ -15,7 +15,7 @@ use sqlparser::ast::{
OrderByExpr, Value as SQLValue, WindowSpec, WindowType,
};

use crate::sql_expr::{parse_extract_date_part, parse_sql_expr};
use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
use crate::SQLContext;

pub(crate) struct SQLFunctionVisitor<'a> {
@@ -1016,7 +1016,7 @@ impl SQLFunctionVisitor<'_> {
},
OctetLength => self.visit_unary(|e| e.str().len_bytes()),
StrPos => {
// note: 1-indexed, not 0-indexed, and returns zero if match not found
// note: SQL is 1-indexed; returns zero if no match found
self.visit_binary(|expr, substring| {
(expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
})
@@ -1092,7 +1092,7 @@ impl SQLFunctionVisitor<'_> {
Substring => {
let args = extract_args(function)?;
match args.len() {
// note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments
// note: SQL is 1-indexed, hence the need for adjustments
2 => self.try_visit_binary(|e, start| {
Ok(match start {
Expr::Literal(Null) => lit(Null),
@@ -1148,7 +1148,13 @@ impl SQLFunctionVisitor<'_> {
// ----
ArrayAgg => self.visit_arr_agg(),
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)),
ArrayGet => {
// note: SQL is 1-indexed, not 0-indexed
self.visit_binary(|e, idx: Expr| {
let idx = adjust_one_indexed_param(idx, true);
e.list().get(idx, true)
})
},
ArrayLength => self.visit_unary(|e| e.list().len()),
ArrayMax => self.visit_unary(|e| e.list().max()),
ArrayMean => self.visit_unary(|e| e.list().mean()),
46 changes: 44 additions & 2 deletions crates/polars-sql/src/sql_expr.rs
@@ -17,7 +17,7 @@ use sqlparser::ast::ExactNumberInfo;
use sqlparser::ast::{
ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind,
DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident,
Interval, JoinConstraint, ObjectName, Query as Subquery, SelectItem, TimezoneInfo,
Interval, JoinConstraint, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo,
TrimWhereField, UnaryOperator, Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
@@ -295,7 +295,7 @@ impl SQLExprVisitor<'_> {
} => self.visit_like(*negated, expr, pattern, escape_char, true),
SQLExpr::Nested(expr) => self.visit_expr(expr),
SQLExpr::Position { expr, r#in } => Ok(
// note: SQL is 1-indexed, not 0-indexed
// note: SQL is 1-indexed
(self
.visit_expr(r#in)?
.str()
@@ -316,6 +316,7 @@ impl SQLExprVisitor<'_> {
.contains(self.visit_expr(pattern)?, true);
Ok(if *negated { matches.not() } else { matches })
},
SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript),
SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
SQLExpr::Trim {
expr,
@@ -443,6 +444,19 @@
}
}

fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
let expr = self.visit_expr(expr)?;
Ok(match subscript {
Subscript::Index { index } => {
let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
expr.list().get(idx, true)
},
Subscript::Slice { .. } => {
polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
},
})
}

/// Handle implicit temporal string comparisons.
///
/// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'"
@@ -1190,6 +1204,34 @@ pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> Pola
})
}

/// Allow an expression that represents a 1-indexed parameter to
/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
match idx {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int(0)) => {
if null_if_zero {
lit(Null)
} else {
idx
}
},
Expr::Literal(LiteralValue::Int(n)) if n < 0 => idx,
Expr::Literal(LiteralValue::Int(n)) => lit(n - 1),
// TODO: when 'saturating_sub' is available, should be able
// to streamline the when/then/otherwise block below -
_ => when(idx.clone().gt(lit(0)))
.then(idx.clone() - lit(1))
.otherwise(if null_if_zero {
when(idx.clone().eq(lit(0)))
.then(lit(Null))
.otherwise(idx.clone())
} else {
idx.clone()
}),
}
}

fn bitstring_to_bytes_literal(b: &String) -> PolarsResult<Expr> {
let n_bits = b.len();
if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 {
6 changes: 3 additions & 3 deletions py-polars/docs/source/reference/sql/functions/array.rst
@@ -105,7 +105,7 @@ Returns the value at the given index in the array.
SELECT
foo, bar,
ARRAY_GET(foo, 1) AS foo_at_1,
ARRAY_GET(bar, 2) AS bar_at_2
ARRAY_GET(bar, 3) AS bar_at_2
FROM self
""")
# shape: (2, 4)
@@ -114,8 +114,8 @@ Returns the value at the given index in the array.
# │ --- ┆ --- ┆ --- ┆ --- │
# │ list[i64] ┆ list[i64] ┆ i64 ┆ i64 │
# ╞═══════════╪════════════╪══════════╪══════════╡
# │ [1, 2] ┆ [6, 7] ┆ 2 ┆ null │
# │ [4, 3, 2] ┆ [8, 9, 10] ┆ 3 ┆ 10 │
# │ [1, 2] ┆ [6, 7] ┆ 1 ┆ null │
# │ [4, 3, 2] ┆ [8, 9, 10] ┆ 4 ┆ 10 │
# └───────────┴────────────┴──────────┴──────────┘
.. _array_length:
95 changes: 85 additions & 10 deletions py-polars/tests/unit/sql/test_array.py
@@ -25,7 +25,8 @@ def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) ->
order_by = "" if not sort_order else f" ORDER BY col0 {sort_order}"
limit_clause = "" if not limit else f" LIMIT {limit}"

res = pl.sql(f"""
res = pl.sql(
f"""
WITH data (col0, col1, col2) as (
VALUES
(1,'a','x'),
@@ -38,7 +39,8 @@ def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) ->
FROM data
GROUP BY col1
ORDER BY col1
""").collect()
"""
).collect()

assert res.rows() == expected

@@ -48,9 +50,17 @@ def test_array_literals() -> None:
res = ctx.execute(
"""
SELECT
a1, a2, ARRAY_AGG(a1) AS a3, ARRAY_AGG(a2) AS a4
a1, a2,
-- test some array ops
ARRAY_AGG(a1) AS a3,
ARRAY_AGG(a2) AS a4,
ARRAY_CONTAINS(a1,20) AS i20,
ARRAY_CONTAINS(a2,'zz') AS izz,
ARRAY_REVERSE(a1) AS ar1,
ARRAY_REVERSE(a2) AS ar2
FROM (
SELECT
-- declare array literals
[10,20,30] AS a1,
['a','b','c'] AS a2,
FROM df
@@ -65,29 +75,94 @@
"a2": [["a", "b", "c"]],
"a3": [[[10, 20, 30]]],
"a4": [[["a", "b", "c"]]],
"i20": [True],
"izz": [False],
"ar1": [[30, 20, 10]],
"ar2": [["c", "b", "a"]],
}
),
)


@pytest.mark.parametrize(
("array_index", "expected"),
[
(-4, None),
(-3, 99),
(-2, 66),
(-1, 33),
(0, None),
(1, 99),
(2, 66),
(3, 33),
(4, None),
],
)
def test_array_indexing(array_index: int, expected: int | None) -> None:
res = pl.sql(
f"""
SELECT
arr[{array_index}] AS idx1,
ARRAY_GET(arr,{array_index}) AS idx2,
FROM (SELECT [99,66,33] AS arr) tbl
"""
).collect()

assert_frame_equal(
res,
pl.DataFrame(
{"idx1": [expected], "idx2": [expected]},
),
check_dtypes=False,
)


def test_array_indexing_by_expr() -> None:
df = pl.DataFrame(
{
"idx": [-2, -1, 0, None, 1, 2, 3],
"arr": [[0, 1, 2, 3], [4, 5], [6], [7, 8, 9], [8, 7], [6, 5, 4], [3, 2, 1]],
}
)
res = df.sql(
"""
SELECT
arr[idx] AS idx1,
ARRAY_GET(arr, idx) AS idx2
FROM self
"""
)
expected = [2, 5, None, None, 8, 5, 1]
assert_frame_equal(res, pl.DataFrame({"idx1": expected, "idx2": expected}))


def test_array_to_string() -> None:
data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}
data = {
"s_values": [["aa", "bb"], [None, "cc"], ["dd", None]],
"n_values": [[999, 777], [None, 555], [333, None]],
}
res = pl.DataFrame(data).sql(
"""
SELECT
ARRAY_TO_STRING(values, '') AS v1,
ARRAY_TO_STRING(values, ':') AS v2,
ARRAY_TO_STRING(values, ':', 'NA') AS v3
ARRAY_TO_STRING(s_values, '') AS vs1,
ARRAY_TO_STRING(s_values, ':') AS vs2,
ARRAY_TO_STRING(s_values, ':', 'NA') AS vs3,
ARRAY_TO_STRING(n_values, '') AS vn1,
ARRAY_TO_STRING(n_values, ':') AS vn2,
ARRAY_TO_STRING(n_values, ':', 'NA') AS vn3
FROM self
"""
)
assert_frame_equal(
res,
pl.DataFrame(
{
"v1": ["aabb", "cc", "dd"],
"v2": ["aa:bb", "cc", "dd"],
"v3": ["aa:bb", "NA:cc", "dd:NA"],
"vs1": ["aabb", "cc", "dd"],
"vs2": ["aa:bb", "cc", "dd"],
"vs3": ["aa:bb", "NA:cc", "dd:NA"],
"vn1": ["999777", "555", "333"],
"vn2": ["999:777", "555", "333"],
"vn3": ["999:777", "NA:555", "333:NA"],
}
),
)
