diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index 8415c5126..1f91fdc9d 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal, TYPE_CHECKING, overload from importlib.metadata import version @@ -11,6 +11,13 @@ get_meta as _get_meta, ) +if TYPE_CHECKING: + import pandas as pd + import polars as pl + import modin.pandas as mpd + import dask.dataframe as dd + import pyarrow as pa + __version__ = version(__name__) import os @@ -27,8 +34,10 @@ "CX_REWRITER_PATH", os.path.join(dir_path, "dependencies/federated-rewriter.jar") ) +Protocol = Literal["csv", "binary", "cursor", "simple", "text"] -def rewrite_conn(conn: str, protocol: str | None = None): + +def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Protocol]: if not protocol: # note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database # drivers to connect. set a compatible protocol and masquerade as the appropriate backend. @@ -47,8 +56,8 @@ def rewrite_conn(conn: str, protocol: str | None = None): def get_meta( conn: str, query: str, - protocol: str | None = None, -): + protocol: Protocol | None = None, +) -> pd.DataFrame: """ Get metadata (header) of the given query (only for pandas) @@ -75,7 +84,7 @@ def partition_sql( partition_on: str, partition_num: int, partition_range: tuple[int, int] | None = None, -): +) -> list[str]: """ Partition the sql query @@ -106,11 +115,11 @@ def read_sql_pandas( sql: list[str] | str, con: str | dict[str, str], index_col: str | None = None, - protocol: str | None = None, + protocol: Protocol | None = None, partition_on: str | None = None, partition_range: tuple[int, int] | None = None, partition_num: int | None = None, -): +) -> pd.DataFrame: """ Run the SQL query, download the data from database into a dataframe. First several parameters are in the same name and order with `pandas.read_sql`. @@ -142,17 +151,103 @@ def read_sql_pandas( ) +# default return pd.DataFrame +@overload def read_sql( conn: str | dict[str, str], query: list[str] | str, *, - return_type: str = "pandas", - protocol: str | None = None, + protocol: Protocol | None = None, partition_on: str | None = None, partition_range: tuple[int, int] | None = None, partition_num: int | None = None, index_col: str | None = None, -): +) -> pd.DataFrame: ... + + +@overload +def read_sql( + conn: str | dict[str, str], + query: list[str] | str, + *, + return_type: Literal["pandas"], + protocol: Protocol | None = None, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + index_col: str | None = None, +) -> pd.DataFrame: ... + + +@overload +def read_sql( + conn: str | dict[str, str], + query: list[str] | str, + *, + return_type: Literal["arrow", "arrow2"], + protocol: Protocol | None = None, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + index_col: str | None = None, +) -> pa.Table: ... + + +@overload +def read_sql( + conn: str | dict[str, str], + query: list[str] | str, + *, + return_type: Literal["modin"], + protocol: Protocol | None = None, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + index_col: str | None = None, +) -> mpd.DataFrame: ... + + +@overload +def read_sql( + conn: str | dict[str, str], + query: list[str] | str, + *, + return_type: Literal["dask"], + protocol: Protocol | None = None, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + index_col: str | None = None, +) -> dd.DataFrame: ... + + +@overload +def read_sql( + conn: str | dict[str, str], + query: list[str] | str, + *, + return_type: Literal["polars", "polars2"], + protocol: Protocol | None = None, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + index_col: str | None = None, +) -> pl.DataFrame: ... + + +def read_sql( + conn: str | dict[str, str], + query: list[str] | str, + *, + return_type: Literal[ + "pandas", "polars", "polars2", "arrow", "arrow2", "modin", "dask" + ] = "pandas", + protocol: Protocol | None = None, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + index_col: str | None = None, +) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table: """ Run the SQL query, download the data from database into a dataframe. @@ -318,7 +413,9 @@ def read_sql( return df -def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]): +def reconstruct_arrow( + result: tuple[list[str], list[list[tuple[int, int]]]], +) -> pa.Table: import pyarrow as pa names, ptrs = result @@ -334,7 +431,7 @@ def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]): return pa.Table.from_batches(rbs) -def reconstruct_pandas(df_infos: dict[str, Any]): +def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame: import pandas as pd data = df_infos["data"] @@ -388,6 +485,6 @@ def remove_ending_semicolon(query: str) -> str: SQL query """ - if query.endswith(';'): + if query.endswith(";"): query = query[:-1] return query