Skip to content

Commit

Permalink
feat: Improve read_database_uri typing (#19334)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-muoto authored Oct 22, 2024
1 parent c6b2329 commit 1d210ab
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
45 changes: 45 additions & 0 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,51 @@ def read_database(
)


@overload
def read_database_uri(
query: str,
uri: str,
*,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
protocol: str | None = None,
engine: Literal["adbc"],
schema_overrides: SchemaDict | None = None,
execute_options: dict[str, Any] | None = None,
) -> DataFrame: ...


@overload
def read_database_uri(
query: list[str] | str,
uri: str,
*,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
protocol: str | None = None,
engine: Literal["connectorx"] | None = None,
schema_overrides: SchemaDict | None = None,
execute_options: None = None,
) -> DataFrame: ...


@overload
def read_database_uri(
query: str,
uri: str,
*,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
protocol: str | None = None,
engine: DbReadEngine | None = None,
schema_overrides: None = None,
execute_options: dict[str, Any] | None = None,
) -> DataFrame: ...


def read_database_uri(
query: list[str] | str,
uri: str,
Expand Down
5 changes: 3 additions & 2 deletions py-polars/tests/unit/io/database/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,18 @@ def test_read_database(
tmp_sqlite_db: Path,
) -> None:
if read_method == "read_database_uri":
connect_using = cast("DbReadEngine", connect_using)
# instantiate the connection ourselves, using connectorx/adbc
df = pl.read_database_uri(
uri=f"sqlite:///{tmp_sqlite_db}",
query="SELECT * FROM test_data",
engine=str(connect_using), # type: ignore[arg-type]
engine=connect_using,
schema_overrides=schema_overrides,
)
df_empty = pl.read_database_uri(
uri=f"sqlite:///{tmp_sqlite_db}",
query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
engine=str(connect_using), # type: ignore[arg-type]
engine=connect_using,
schema_overrides=schema_overrides,
)
elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
Expand Down

0 comments on commit 1d210ab

Please sign in to comment.