Skip to content

Commit

Permalink
refactor(python): Move dedicated inference code out of io.database
Browse files Browse the repository at this point in the history
…executor module (#15526)
  • Loading branch information
alexander-beedie authored Apr 7, 2024
1 parent 63ad8af commit 293833d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 52 deletions.
58 changes: 6 additions & 52 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,19 @@
import re
from collections.abc import Coroutine
from contextlib import suppress
from inspect import Parameter, isclass, signature
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from polars import functions as F
from polars._utils.various import parse_version
from polars.convert import from_arrow
from polars.datatypes import (
INTEGER_DTYPES,
N_INFER_DEFAULT,
UNSIGNED_INTEGER_DTYPES,
Decimal,
Float32,
Float64,
)
from polars.datatypes.convert import _map_py_type_to_dtype
from polars.exceptions import ModuleUpgradeRequired, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy
from polars.io.database._inference import (
_infer_dtype_from_database_typename,
_integer_dtype_from_nbits,
)
from polars.io.database._inference import _infer_dtype_from_cursor_description
from polars.io.database._utils import _run_async

if TYPE_CHECKING:
Expand All @@ -46,7 +37,6 @@
from typing_extensions import Self

from polars import DataFrame
from polars.datatypes import PolarsDataType
from polars.type_aliases import ConnectionOrCursor, Cursor, SchemaDict

try:
Expand Down Expand Up @@ -287,8 +277,8 @@ def _from_rows(
if is_async:
original_result.close()

@staticmethod
def _inject_type_overrides(
self,
description: dict[str, Any],
schema_overrides: SchemaDict,
) -> SchemaDict:
Expand All @@ -297,50 +287,14 @@ def _inject_type_overrides(
Notes
-----
This is limited; the `type_code` property may contain almost anything,
This is limited; the `type_code` description attr may contain almost anything,
from strings or python types to driver-specific codes, classes, enums, etc.
We currently only do the additional inference from string/python type values.
(Further refinement will require per-driver module knowledge and lookups).
"""
dtype: PolarsDataType | None = None
for nm, desc in description.items():
if desc is None:
continue
elif nm not in schema_overrides:
type_code, _disp_size, internal_size, prec, scale, _null_ok = desc
if isclass(type_code):
# python types, eg: int, float, str, etc
with suppress(TypeError):
dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type]

elif isinstance(type_code, str):
# database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
dtype = _infer_dtype_from_database_typename(
value=type_code,
raise_unmatched=False,
)

if dtype is not None:
# check additional cursor attrs to improve dtype specification
if dtype == Float64 and internal_size == 4:
dtype = Float32

elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
bits = internal_size * 8
dtype = _integer_dtype_from_nbits(
bits,
unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
default=dtype,
)
elif (
dtype == Decimal
and isinstance(prec, int)
and isinstance(scale, int)
and prec <= 38
and scale <= 38
):
dtype = Decimal(prec, scale)

if desc is not None and nm not in schema_overrides:
dtype = _infer_dtype_from_cursor_description(self.cursor, desc)
if dtype is not None:
schema_overrides[nm] = dtype # type: ignore[index]

Expand Down
49 changes: 49 additions & 0 deletions py-polars/polars/io/database/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import functools
import re
from contextlib import suppress
from inspect import isclass
from typing import TYPE_CHECKING, Any

from polars.datatypes import (
INTEGER_DTYPES,
UNSIGNED_INTEGER_DTYPES,
Binary,
Boolean,
Date,
Expand All @@ -25,6 +29,7 @@
UInt32,
UInt64,
)
from polars.datatypes.convert import _map_py_type_to_dtype

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType
Expand Down Expand Up @@ -182,6 +187,50 @@ def _infer_dtype_from_database_typename(
return dtype


def _infer_dtype_from_cursor_description(
cursor: Any,
description: tuple[Any, ...],
) -> PolarsDataType | None:
"""Attempt to infer Polars dtype from database cursor description `type_code`."""
type_code, _disp_size, internal_size, precision, scale, _null_ok = description
dtype: PolarsDataType | None = None

if isclass(type_code):
# python types, eg: int, float, str, etc
with suppress(TypeError):
dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type]

elif isinstance(type_code, str):
# database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
dtype = _infer_dtype_from_database_typename(
value=type_code,
raise_unmatched=False,
)

# check additional cursor attrs to refine dtype specification
if dtype is not None:
if dtype == Float64 and internal_size == 4:
dtype = Float32

elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
bits = internal_size * 8
dtype = _integer_dtype_from_nbits(
bits,
unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
default=dtype,
)
elif (
dtype == Decimal
and isinstance(precision, int)
and isinstance(scale, int)
and precision <= 38
and scale <= 38
):
dtype = Decimal(precision, scale)

return dtype


@functools.lru_cache(8)
def _integer_dtype_from_nbits(
bits: int,
Expand Down

0 comments on commit 293833d

Please sign in to comment.