Skip to content

Commit

Permalink
feat: Support SQL "SELECT" with no tables, optimise registration of…
Browse files Browse the repository at this point in the history
… globals (#16836)
  • Loading branch information
alexander-beedie authored Jun 10, 2024
1 parent 8934698 commit 92af769
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 96 deletions.
33 changes: 22 additions & 11 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,23 +474,20 @@ impl SQLContext {
/// Execute the 'SELECT' part of the query.
fn execute_select(&mut self, select_stmt: &Select, query: &Query) -> PolarsResult<LazyFrame> {
// Determine involved dataframes.
// Implicit joins require some more work in query parsers, explicit joins are preferred for now.
let sql_tbl: &TableWithJoins = select_stmt
.from
.first()
.ok_or_else(|| polars_err!(SQLSyntax: "no table name provided in query"))?;
// Note: implicit joins require more work in query parsing,
// explicit joins are preferred for now (ref: #16662)

let mut lf = self.execute_from_statement(sql_tbl)?;
let mut lf = if select_stmt.from.is_empty() {
DataFrame::empty().lazy()
} else {
self.execute_from_statement(select_stmt.from.first().unwrap())?
};
let mut contains_wildcard = false;
let mut contains_wildcard_exclude = false;

// Filter expression.
let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?);
if let Some(expr) = select_stmt.selection.as_ref() {
let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?;
lf = self.process_subqueries(lf, vec![&mut filter_expression]);
lf = lf.filter(filter_expression);
}
lf = self.process_where(lf, &select_stmt.selection)?;

// Column projections.
let projections: Vec<_> = select_stmt
Expand Down Expand Up @@ -668,6 +665,20 @@ impl SQLContext {
Ok(lf)
}

fn process_where(
&mut self,
mut lf: LazyFrame,
expr: &Option<SQLExpr>,
) -> PolarsResult<LazyFrame> {
if let Some(expr) = expr {
let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?);
let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?;
lf = self.process_subqueries(lf, vec![&mut filter_expression]);
lf = lf.filter(filter_expression);
}
Ok(lf)
}

pub(super) fn process_join(
&self,
left_tbl: LazyFrame,
Expand Down
89 changes: 19 additions & 70 deletions py-polars/polars/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def __init__(
frames: Mapping[str, CompatibleFrameType | None] | None = None,
*,
register_globals: bool | int = False,
all_compatible: bool = False,
eager: bool = False,
**named_frames: CompatibleFrameType | None,
) -> None:
Expand All @@ -163,13 +162,11 @@ def __init__(
A `{name:frame, ...}` mapping which can include Polars frames *and*
pandas DataFrames, Series and pyarrow Table and RecordBatch objects.
register_globals
Register compatible objects found in the globals, automatically mapping
their variable name to a table name. If given an integer then only the
Register compatible objects (polars DataFrame, LazyFrame, and Series) found
in the globals, automatically mapping their variable name to a table name.
To register other objects (pandas/pyarrow data) pass them explicitly, or
call the `execute_global` classmethod. If given an integer then only the
most recent "n" objects found will be registered.
all_compatible
If `register_globals` is set this option controls whether we *also* register
all pandas DataFrame, Series, and pyarrow Table and RecordBatch objects.
If False, only Polars classes are registered with the SQL engine.
eager
If True, returns execution results as `DataFrame` instead of `LazyFrame`.
(Note that the query itself is always executed in lazy-mode; this parameter
Expand Down Expand Up @@ -203,7 +200,7 @@ def __init__(
frames = dict(frames or {})
if register_globals:
for name, obj in _get_frame_locals(
all_compatible=all_compatible,
all_compatible=False,
n_objects=None if (register_globals is True) else None,
).items():
if name not in frames and name not in named_frames:
Expand Down Expand Up @@ -237,9 +234,10 @@ def execute_global(
Notes
-----
* This convenience method automatically registers all compatible objects in
the local stack, mapping their variable name to a table name. Note that in
addition to polars DataFrame, LazyFrame, and Series this method will *also*
register pandas DataFrame, Series, and pyarrow Table and RecordBatch objects.
the local stack that are referenced in the query, mapping their variable name
to a table name. Note that in addition to polars DataFrame, LazyFrame, and
Series this method *also* registers pandas DataFrame, Series, and pyarrow
Table and RecordBatch objects.
* Instead of calling this classmethod you should consider using `pl.sql`,
which will use this code internally.
Expand Down Expand Up @@ -274,13 +272,16 @@ def execute_global(
# basic extraction of possible table names from the query, so we don't register
# unnecessary objects from the globals (ideally we shuoold look to make the
# underlying `sqlparser-rs` lib parse the query to identify table names)
q = re.split(r"\bFROM\b", query, maxsplit=1, flags=re.I)[1]
possible_names = {
nm
for nm in re.split(r"\s", q)
if re.match(r'^("[^"]+")$', nm)
or (nm.isidentifier() and nm.lower() not in _SQL_KEYWORDS_)
}
q = re.split(r"\bFROM\b", query, maxsplit=1, flags=re.I)
possible_names = (
{
nm.strip('"')
for nm in re.split(r"\s", q[1])
if re.match(r'^("[^"]+")$', nm) or nm.isidentifier()
}
if len(q) > 1
else set()
)
# get compatible frame objects from the globals, constraining by possible names
named_frames = _get_frame_locals(all_compatible=True, named=possible_names)
with cls(frames=named_frames, register_globals=False) as ctx:
Expand Down Expand Up @@ -668,56 +669,4 @@ def tables(self) -> list[str]:
return sorted(self._ctxt.get_tables())


_SQL_KEYWORDS_ = {
"and",
"anti",
"array",
"as",
"asc",
"boolean",
"by",
"case",
"create",
"date",
"datetime",
"desc",
"distinct",
"double",
"drop",
"exclude",
"float",
"from",
"full",
"group",
"having",
"in",
"inner",
"int",
"interval",
"join",
"left",
"limit",
"not",
"null",
"offset",
"on",
"or",
"order",
"outer",
"regexp",
"right",
"rlike",
"select",
"semi",
"show",
"table",
"tables",
"then",
"using",
"when",
"where",
"with",
}


__all__ = ["SQLContext"]
22 changes: 22 additions & 0 deletions py-polars/tests/unit/sql/test_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,25 @@ def test_intervals() -> None:
match="unary ops are not valid on interval strings; found -'7d'",
):
ctx.execute("SELECT INTERVAL -'7d' AS one_week_ago FROM df")


def test_select_literals_no_table() -> None:
res = pl.sql("SELECT 1 AS one, '2' AS two, 3.0 AS three", eager=True)
assert res.to_dict(as_series=False) == {
"one": [1],
"two": ["2"],
"three": [3.0],
}


def test_select_from_table_with_reserved_names() -> None:
select = pl.DataFrame({"select": [1, 2, 3], "from": [4, 5, 6]}) # noqa: F841
out = pl.sql(
"""
SELECT "from", "select"
FROM "select"
WHERE "from" >= 5 AND "select" % 2 != 1
""",
eager=True,
)
assert out.rows() == [(5, 2)]
29 changes: 14 additions & 15 deletions py-polars/tests/unit/sql/test_miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,19 @@ def test_sql_on_compatible_frame_types() -> None:
(df["a"] * 2).rename("c"), # polars series
(dfp["a"] * 2).rename("c"), # pandas series
):
res = pl.sql("""
SELECT a, b, SUM(c) AS cc FROM (
SELECT * FROM df -- polars frame
UNION ALL SELECT * FROM dfp -- pandas frame
UNION ALL SELECT * FROM dfa -- pyarrow table
UNION ALL SELECT * FROM dfb -- pyarrow record batch
) tbl
INNER JOIN dfs ON dfs.c == tbl.b -- join on pandas/polars series
GROUP BY "a", "b"
ORDER BY "a", "b"
""").collect()
res = pl.sql(
"""
SELECT a, b, SUM(c) AS cc FROM (
SELECT * FROM df -- polars frame
UNION ALL SELECT * FROM dfp -- pandas frame
UNION ALL SELECT * FROM dfa -- pyarrow table
UNION ALL SELECT * FROM dfb -- pyarrow record batch
) tbl
INNER JOIN dfs ON dfs.c == tbl.b -- join on pandas/polars series
GROUP BY "a", "b"
ORDER BY "a", "b"
"""
).collect()

expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [16, 24]})
assert_frame_equal(left=expected, right=res)
Expand All @@ -191,7 +193,4 @@ def test_sql_on_compatible_frame_types() -> None:

# don't register all compatible objects
with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"):
pl.SQLContext(
register_globals=True,
all_compatible=False,
).execute("SELECT * FROM dfp")
pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp")

0 comments on commit 92af769

Please sign in to comment.