Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(python): Move dedicated inference code out of io.database executor module #15526

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 6 additions & 52 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,19 @@
import re
from collections.abc import Coroutine
from contextlib import suppress
from inspect import Parameter, isclass, signature
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from polars import functions as F
from polars._utils.various import parse_version
from polars.convert import from_arrow
from polars.datatypes import (
INTEGER_DTYPES,
N_INFER_DEFAULT,
UNSIGNED_INTEGER_DTYPES,
Decimal,
Float32,
Float64,
)
from polars.datatypes.convert import _map_py_type_to_dtype
from polars.exceptions import ModuleUpgradeRequired, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy
from polars.io.database._inference import (
_infer_dtype_from_database_typename,
_integer_dtype_from_nbits,
)
from polars.io.database._inference import _infer_dtype_from_cursor_description
from polars.io.database._utils import _run_async

if TYPE_CHECKING:
Expand All @@ -46,7 +37,6 @@
from typing_extensions import Self

from polars import DataFrame
from polars.datatypes import PolarsDataType
from polars.type_aliases import ConnectionOrCursor, Cursor, SchemaDict

try:
Expand Down Expand Up @@ -287,8 +277,8 @@ def _from_rows(
if is_async:
original_result.close()

@staticmethod
def _inject_type_overrides(
self,
description: dict[str, Any],
schema_overrides: SchemaDict,
) -> SchemaDict:
Expand All @@ -297,50 +287,14 @@ def _inject_type_overrides(

Notes
-----
This is limited; the `type_code` property may contain almost anything,
This is limited; the `type_code` description attr may contain almost anything,
from strings or python types to driver-specific codes, classes, enums, etc.
We currently only do the additional inference from string/python type values.
(Further refinement will require per-driver module knowledge and lookups).
"""
dtype: PolarsDataType | None = None
for nm, desc in description.items():
if desc is None:
continue
elif nm not in schema_overrides:
type_code, _disp_size, internal_size, prec, scale, _null_ok = desc
if isclass(type_code):
# python types, eg: int, float, str, etc
with suppress(TypeError):
dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type]

elif isinstance(type_code, str):
# database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
dtype = _infer_dtype_from_database_typename(
value=type_code,
raise_unmatched=False,
)

if dtype is not None:
# check additional cursor attrs to improve dtype specification
if dtype == Float64 and internal_size == 4:
dtype = Float32

elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
bits = internal_size * 8
dtype = _integer_dtype_from_nbits(
bits,
unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
default=dtype,
)
elif (
dtype == Decimal
and isinstance(prec, int)
and isinstance(scale, int)
and prec <= 38
and scale <= 38
):
dtype = Decimal(prec, scale)

if desc is not None and nm not in schema_overrides:
dtype = _infer_dtype_from_cursor_description(self.cursor, desc)
if dtype is not None:
schema_overrides[nm] = dtype # type: ignore[index]

Expand Down
49 changes: 49 additions & 0 deletions py-polars/polars/io/database/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import functools
import re
from contextlib import suppress
from inspect import isclass
from typing import TYPE_CHECKING, Any

from polars.datatypes import (
INTEGER_DTYPES,
UNSIGNED_INTEGER_DTYPES,
Binary,
Boolean,
Date,
Expand All @@ -25,6 +29,7 @@
UInt32,
UInt64,
)
from polars.datatypes.convert import _map_py_type_to_dtype

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType
Expand Down Expand Up @@ -182,6 +187,50 @@ def _infer_dtype_from_database_typename(
return dtype


def _infer_dtype_from_cursor_description(
cursor: Any,
description: tuple[Any, ...],
) -> PolarsDataType | None:
"""Attempt to infer Polars dtype from database cursor description `type_code`."""
type_code, _disp_size, internal_size, precision, scale, _null_ok = description
dtype: PolarsDataType | None = None

if isclass(type_code):
# python types, eg: int, float, str, etc
with suppress(TypeError):
dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type]

elif isinstance(type_code, str):
# database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
dtype = _infer_dtype_from_database_typename(
value=type_code,
raise_unmatched=False,
)

# check additional cursor attrs to refine dtype specification
if dtype is not None:
if dtype == Float64 and internal_size == 4:
dtype = Float32

elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
bits = internal_size * 8
dtype = _integer_dtype_from_nbits(
bits,
unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
default=dtype,
)
elif (
dtype == Decimal
and isinstance(precision, int)
and isinstance(scale, int)
and precision <= 38
and scale <= 38
):
dtype = Decimal(precision, scale)

return dtype


@functools.lru_cache(8)
def _integer_dtype_from_nbits(
bits: int,
Expand Down
Loading