Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing: optimize type annotations #609

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 110 additions & 13 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, Literal, TYPE_CHECKING, overload

from importlib.metadata import version

Expand All @@ -11,6 +11,13 @@
get_meta as _get_meta,
)

if TYPE_CHECKING:
import pandas as pd
import polars as pl
import modin.pandas as mpd
import dask.dataframe as dd
import pyarrow as pa

__version__ = version(__name__)

import os
Expand All @@ -27,8 +34,10 @@
"CX_REWRITER_PATH", os.path.join(dir_path, "dependencies/federated-rewriter.jar")
)

# Wire/transfer protocol options accepted by the underlying connectors
# (passed through to the Rust core via the `protocol` parameters below).
Protocol = Literal["csv", "binary", "cursor", "simple", "text"]

def rewrite_conn(conn: str, protocol: str | None = None):

def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Protocol]:
if not protocol:
# note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
# drivers to connect. set a compatible protocol and masquerade as the appropriate backend.
Expand All @@ -47,8 +56,8 @@ def rewrite_conn(conn: str, protocol: str | None = None):
def get_meta(
conn: str,
query: str,
protocol: str | None = None,
):
protocol: Protocol | None = None,
) -> pd.DataFrame:
"""
Get metadata (header) of the given query (only for pandas)

Expand All @@ -75,7 +84,7 @@ def partition_sql(
partition_on: str,
partition_num: int,
partition_range: tuple[int, int] | None = None,
):
) -> list[str]:
"""
Partition the sql query

Expand Down Expand Up @@ -106,11 +115,11 @@ def read_sql_pandas(
sql: list[str] | str,
con: str | dict[str, str],
index_col: str | None = None,
protocol: str | None = None,
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
):
) -> pd.DataFrame:
"""
Run the SQL query, download the data from database into a dataframe.
First several parameters are in the same name and order with `pandas.read_sql`.
Expand Down Expand Up @@ -142,17 +151,103 @@ def read_sql_pandas(
)


# Default overload: ``return_type`` omitted (or explicitly "pandas") yields a
# pandas DataFrame.  The parameter is typed ``Literal["pandas"]`` rather than
# ``str``: overload resolution picks the first matching variant, and a plain
# ``str`` annotation here would match every string argument and shadow the
# more specific ``Literal`` overloads that follow (e.g. return_type="polars"
# would be wrongly inferred as pd.DataFrame).
@overload
def read_sql(
    conn: str | dict[str, str],
    query: list[str] | str,
    *,
    return_type: Literal["pandas"] = "pandas",
    protocol: Protocol | None = None,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
) -> pd.DataFrame: ...


# Overload: explicit return_type="pandas" -> pandas DataFrame.
@overload
def read_sql(
    conn: str | dict[str, str],
    query: list[str] | str,
    *,
    return_type: Literal["pandas"],
    protocol: Protocol | None = None,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
) -> pd.DataFrame: ...


# Overload: return_type "arrow" or "arrow2" -> pyarrow Table.
@overload
def read_sql(
    conn: str | dict[str, str],
    query: list[str] | str,
    *,
    return_type: Literal["arrow", "arrow2"],
    protocol: Protocol | None = None,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
) -> pa.Table: ...


# Overload: return_type="modin" -> modin DataFrame.
@overload
def read_sql(
    conn: str | dict[str, str],
    query: list[str] | str,
    *,
    return_type: Literal["modin"],
    protocol: Protocol | None = None,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
) -> mpd.DataFrame: ...


# Overload: return_type="dask" -> dask DataFrame.
@overload
def read_sql(
    conn: str | dict[str, str],
    query: list[str] | str,
    *,
    return_type: Literal["dask"],
    protocol: Protocol | None = None,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
) -> dd.DataFrame: ...


# Overload: return_type "polars" or "polars2" -> polars DataFrame.
@overload
def read_sql(
    conn: str | dict[str, str],
    query: list[str] | str,
    *,
    return_type: Literal["polars", "polars2"],
    protocol: Protocol | None = None,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
) -> pl.DataFrame: ...


def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal[
"pandas", "polars", "polars2", "arrow", "arrow2", "modin", "dask"
] = "pandas",
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
"""
Run the SQL query, download the data from database into a dataframe.

Expand Down Expand Up @@ -318,7 +413,9 @@ def read_sql(
return df


def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
def reconstruct_arrow(
result: tuple[list[str], list[list[tuple[int, int]]]],
) -> pa.Table:
import pyarrow as pa

names, ptrs = result
Expand All @@ -334,7 +431,7 @@ def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
return pa.Table.from_batches(rbs)


def reconstruct_pandas(df_infos: dict[str, Any]):
def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame:
import pandas as pd

data = df_infos["data"]
Expand Down Expand Up @@ -388,6 +485,6 @@ def remove_ending_semicolon(query: str) -> str:
SQL query

"""
if query.endswith(';'):
if query.endswith(";"):
query = query[:-1]
return query
Loading