diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index a6a5acc6d..456042077 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -1,16 +1,17 @@ from __future__ import annotations - import importlib from importlib.metadata import version -from typing import Any, Literal, TYPE_CHECKING, overload +from typing import Literal, TYPE_CHECKING, overload from .connectorx import ( read_sql as _read_sql, partition_sql as _partition_sql, read_sql2 as _read_sql2, get_meta as _get_meta, ) if TYPE_CHECKING: +    # Type-only names: declared in connectorx.pyi, absent from the compiled module. +    from .connectorx import _DataframeInfos, _ArrowInfos @@ -394,9 +395,7 @@ def read_sql( return df -def reconstruct_arrow( - result: tuple[list[str], list[list[tuple[int, int]]]], -) -> pa.Table: +def reconstruct_arrow(result: _ArrowInfos) -> pa.Table: import pyarrow as pa names, ptrs = result @@ -412,7 +411,7 @@ def reconstruct_arrow( return pa.Table.from_batches(rbs) -def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame: +def reconstruct_pandas(df_infos: _DataframeInfos) -> pd.DataFrame: import pandas as pd data = df_infos["data"] diff --git a/connectorx-python/connectorx/connectorx.pyi b/connectorx-python/connectorx/connectorx.pyi index b556d918b..c1ccbfa53 100644 --- a/connectorx-python/connectorx/connectorx.pyi +++ b/connectorx-python/connectorx/connectorx.pyi @@ -1,11 +1,22 @@ from __future__ import annotations -from typing import overload, Literal, Any, TypeAlias -import pandas as pd +from typing import overload, Literal, Any, TypeAlias, TypedDict +import numpy as np _ArrowArrayPtr: TypeAlias = int _ArrowSchemaPtr: TypeAlias = int -_Column: TypeAlias = str +_Header: TypeAlias = str + +class PandasBlockInfo: + cids: list[int] + dt: int + +class _DataframeInfos(TypedDict): + data: list[tuple[np.ndarray, ...] | np.ndarray] + headers: list[_Header] + block_infos: list[PandasBlockInfo] + +_ArrowInfos: TypeAlias = tuple[list[_Header], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]] @@ -14,7 +25,7 @@ def read_sql( conn: str, return_type: Literal["pandas"], protocol: str | None, queries: list[str] | None, partition_query: dict[str, Any] | None, -) -> pd.DataFrame: ... +) -> _DataframeInfos: ... @overload def read_sql( conn: str, return_type: Literal["arrow"], protocol: str | None, queries: list[str] | None, partition_query: dict[str, Any] | None, -) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ... +) -> _ArrowInfos: ... def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ... def read_sql2( - sql: str, db_map: dict[str, str] -) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ... +def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ... def get_meta( conn: str, protocol: Literal["csv", "binary", "cursor", "simple", "text"] | None, query: str, -) -> dict[str, Any]: ... +) -> _DataframeInfos: ...