From 1d210abb0daf296aa731d3d6b955a4f18b5bb967 Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Tue, 22 Oct 2024 09:04:37 -0500 Subject: [PATCH] feat: Improve `read_database_uri` typing (#19334) --- py-polars/polars/io/database/functions.py | 45 +++++++++++++++++++ py-polars/tests/unit/io/database/test_read.py | 5 ++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index ac5dcaaac3e7..aa686e2f813e 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -263,6 +263,51 @@ def read_database( ) +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["adbc"], + schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: list[str] | str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["connectorx"] | None = None, + schema_overrides: SchemaDict | None = None, + execute_options: None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: DbReadEngine | None = None, + schema_overrides: None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + def read_database_uri( query: list[str] | str, uri: str, diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index deb44a5a79f4..fb6fa8dca0ed 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -292,17 +292,18 @@ def test_read_database( tmp_sqlite_db: Path, ) -> None: if read_method == "read_database_uri": + connect_using = cast("DbReadEngine", connect_using) # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) df_empty = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data WHERE name LIKE '%polars%'", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: