diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index aa686e2f813e..21e436dc0557 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -25,10 +25,12 @@ except ImportError: Selectable: TypeAlias = Any # type: ignore[no-redef] + from sqlalchemy.sql.elements import TextClause + @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: Literal[False] = ..., @@ -41,7 +43,7 @@ def read_database( @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: Literal[True], @@ -54,7 +56,7 @@ def read_database( @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: bool, @@ -66,7 +68,7 @@ def read_database( def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: bool = False, diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index fb6fa8dca0ed..69e7853172a1 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -12,7 +12,7 @@ import pyarrow as pa import pytest import sqlalchemy -from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast @@ -383,6 +383,39 @@ def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: assert_frame_equal(batches[0], expected) +def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None: + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + # establish sqlalchemy "textclause" and validate usage + textclause_query = text(""" + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < 0 + """) + + expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(textclause_query, connection=conn), + expected, + ) + + batches = list( + pl.read_database( + textclause_query, + connection=conn, + iter_batches=True, + batch_size=1, + ) + ) + assert len(batches) == 1 + assert_frame_equal(batches[0], expected) + + def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")