fix: correct read sql return type annotation
zen-xu committed Apr 19, 2024
1 parent c5d79ad commit a336e00
Showing 2 changed files with 23 additions and 15 deletions.
11 changes: 5 additions & 6 deletions connectorx-python/connectorx/__init__.py
@@ -1,16 +1,17 @@
from __future__ import annotations


import importlib
from importlib.metadata import version

from typing import Any, Literal, TYPE_CHECKING, overload
from typing import Literal, TYPE_CHECKING, overload

from .connectorx import (
read_sql as _read_sql,
partition_sql as _partition_sql,
read_sql2 as _read_sql2,
get_meta as _get_meta,
_DataframeInfos,
_ArrowInfos,
)

if TYPE_CHECKING:
@@ -394,9 +395,7 @@ def read_sql(
return df


def reconstruct_arrow(
result: tuple[list[str], list[list[tuple[int, int]]]],
) -> pa.Table:
def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
import pyarrow as pa

names, ptrs = result
@@ -412,7 +411,7 @@ def read_sql(
return pa.Table.from_batches(rbs)


def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame:
def reconstruct_pandas(df_infos: _DataframeInfos) -> pd.DataFrame:
import pandas as pd

data = df_infos["data"]
27 changes: 18 additions & 9 deletions connectorx-python/connectorx/connectorx.pyi
@@ -1,11 +1,22 @@
from __future__ import annotations

from typing import overload, Literal, Any, TypeAlias
import pandas as pd
from typing import overload, Literal, Any, TypeAlias, TypedDict
import numpy as np

_ArrowArrayPtr: TypeAlias = int
_ArrowSchemaPtr: TypeAlias = int
_Column: TypeAlias = str
_Header: TypeAlias = str

class PandasBlockInfo:
cids: list[int]
dt: int

class _DataframeInfos(TypedDict):
data: list[tuple[np.ndarray, ...] | np.ndarray]
headers: list[_Header]
block_infos: list[PandasBlockInfo]

_ArrowInfos = tuple[list[_Header], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]

@overload
def read_sql(
@@ -14,21 +25,19 @@ def read_sql(
protocol: str | None,
queries: list[str] | None,
partition_query: dict[str, Any] | None,
) -> pd.DataFrame: ...
) -> _DataframeInfos: ...
@overload
def read_sql(
conn: str,
return_type: Literal["arrow", "arrow2"],
protocol: str | None,
queries: list[str] | None,
partition_query: dict[str, Any] | None,
) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ...
) -> _ArrowInfos: ...
def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
def read_sql2(
sql: str, db_map: dict[str, str]
) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ...
def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
def get_meta(
conn: str,
protocol: Literal["csv", "binary", "cursor", "simple", "text"] | None,
query: str,
) -> dict[str, Any]: ...
) -> _DataframeInfos: ...
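
A minimal sketch of how the tightened annotations compose end-to-end, based only on the signatures visible in this diff. The names _read_sql, reconstruct_pandas, and _DataframeInfos come straight from the changed files; the "pandas" return_type literal and the positional call shape are assumptions for illustration, not a confirmed API.

# sketch.py - assumes connectorx with this commit applied
from __future__ import annotations

from typing import TYPE_CHECKING

from connectorx import reconstruct_pandas
from connectorx.connectorx import read_sql as _read_sql

if TYPE_CHECKING:
    import pandas as pd
    from connectorx.connectorx import _DataframeInfos


def fetch_frame(conn: str, query: str) -> pd.DataFrame:
    # Under the old stub this intermediate value was annotated as a
    # pd.DataFrame; with this commit it is the _DataframeInfos TypedDict
    # (raw numpy blocks, headers, block_infos), which is what
    # reconstruct_pandas actually consumes.
    raw: _DataframeInfos = _read_sql(conn, "pandas", None, [query], None)
    return reconstruct_pandas(raw)

The arrow path is analogous: the "arrow"/"arrow2" overload and read_sql2 now return the _ArrowInfos alias (header names plus arrays of Arrow array/schema pointers), which is the shape reconstruct_arrow unpacks before building a pa.Table.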
