diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index f8c22d49d8e0..f001f2957165 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -3,28 +3,19 @@ import re from collections.abc import Coroutine from contextlib import suppress -from inspect import Parameter, isclass, signature +from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Iterable, Sequence from polars import functions as F from polars._utils.various import parse_version from polars.convert import from_arrow from polars.datatypes import ( - INTEGER_DTYPES, N_INFER_DEFAULT, - UNSIGNED_INTEGER_DTYPES, - Decimal, - Float32, - Float64, ) -from polars.datatypes.convert import _map_py_type_to_dtype from polars.exceptions import ModuleUpgradeRequired, UnsuitableSQLError from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy -from polars.io.database._inference import ( - _infer_dtype_from_database_typename, - _integer_dtype_from_nbits, -) +from polars.io.database._inference import _infer_dtype_from_cursor_description from polars.io.database._utils import _run_async if TYPE_CHECKING: @@ -46,7 +37,6 @@ from typing_extensions import Self from polars import DataFrame - from polars.datatypes import PolarsDataType from polars.type_aliases import ConnectionOrCursor, Cursor, SchemaDict try: @@ -287,8 +277,8 @@ def _from_rows( if is_async: original_result.close() - @staticmethod def _inject_type_overrides( + self, description: dict[str, Any], schema_overrides: SchemaDict, ) -> SchemaDict: @@ -297,50 +287,14 @@ def _inject_type_overrides( Notes ----- - This is limited; the `type_code` property may contain almost anything, + This is limited; the `type_code` description attr may contain almost anything, from strings or python types to driver-specific codes, classes, enums, etc. We currently only do the additional inference from string/python type values. (Further refinement will require per-driver module knowledge and lookups). """ - dtype: PolarsDataType | None = None for nm, desc in description.items(): - if desc is None: - continue - elif nm not in schema_overrides: - type_code, _disp_size, internal_size, prec, scale, _null_ok = desc - if isclass(type_code): - # python types, eg: int, float, str, etc - with suppress(TypeError): - dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type] - - elif isinstance(type_code, str): - # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc - dtype = _infer_dtype_from_database_typename( - value=type_code, - raise_unmatched=False, - ) - - if dtype is not None: - # check additional cursor attrs to improve dtype specification - if dtype == Float64 and internal_size == 4: - dtype = Float32 - - elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8): - bits = internal_size * 8 - dtype = _integer_dtype_from_nbits( - bits, - unsigned=(dtype in UNSIGNED_INTEGER_DTYPES), - default=dtype, - ) - elif ( - dtype == Decimal - and isinstance(prec, int) - and isinstance(scale, int) - and prec <= 38 - and scale <= 38 - ): - dtype = Decimal(prec, scale) - + if desc is not None and nm not in schema_overrides: + dtype = _infer_dtype_from_cursor_description(self.cursor, desc) if dtype is not None: schema_overrides[nm] = dtype # type: ignore[index] diff --git a/py-polars/polars/io/database/_inference.py b/py-polars/polars/io/database/_inference.py index 2eb630f0cda0..daf06571434d 100644 --- a/py-polars/polars/io/database/_inference.py +++ b/py-polars/polars/io/database/_inference.py @@ -2,9 +2,13 @@ import functools import re +from contextlib import suppress +from inspect import isclass from typing import TYPE_CHECKING, Any from polars.datatypes import ( + INTEGER_DTYPES, + UNSIGNED_INTEGER_DTYPES, Binary, Boolean, Date, @@ -25,6 +29,7 @@ UInt32, UInt64, ) +from polars.datatypes.convert import _map_py_type_to_dtype if TYPE_CHECKING: from polars.type_aliases import PolarsDataType @@ -182,6 +187,50 @@ def _infer_dtype_from_database_typename( return dtype +def _infer_dtype_from_cursor_description( + cursor: Any, + description: tuple[Any, ...], +) -> PolarsDataType | None: + """Attempt to infer Polars dtype from database cursor description `type_code`.""" + type_code, _disp_size, internal_size, precision, scale, _null_ok = description + dtype: PolarsDataType | None = None + + if isclass(type_code): + # python types, eg: int, float, str, etc + with suppress(TypeError): + dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type] + + elif isinstance(type_code, str): + # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc + dtype = _infer_dtype_from_database_typename( + value=type_code, + raise_unmatched=False, + ) + + # check additional cursor attrs to refine dtype specification + if dtype is not None: + if dtype == Float64 and internal_size == 4: + dtype = Float32 + + elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8): + bits = internal_size * 8 + dtype = _integer_dtype_from_nbits( + bits, + unsigned=(dtype in UNSIGNED_INTEGER_DTYPES), + default=dtype, + ) + elif ( + dtype == Decimal + and isinstance(precision, int) + and isinstance(scale, int) + and precision <= 38 + and scale <= 38 + ): + dtype = Decimal(precision, scale) + + return dtype + + @functools.lru_cache(8) def _integer_dtype_from_nbits( bits: int,