From 0b65c330acdb46e907a34accad190b027f23f2b1 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 18 Mar 2024 20:43:56 +0400 Subject: [PATCH] feat(python): improved dtype inference/refinement for `read_database` results (#15126) --- py-polars/polars/datatypes/convert.py | 191 +++++++++++++++++- py-polars/polars/io/database.py | 110 ++++++++-- py-polars/tests/unit/io/test_database_read.py | 83 ++++++++ 3 files changed, 363 insertions(+), 21 deletions(-) diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index d44bf9672908..6630069575ba 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -135,10 +135,191 @@ def _map_py_type_to_dtype( dtype if nested is None else dtype(_map_py_type_to_dtype(nested)) # type: ignore[operator] ) - msg = "invalid type" + msg = f"unrecognised Python type: {python_dtype!r}" raise TypeError(msg) +def _timeunit_from_precision(precision: int | str | None) -> str | None: + """Return `time_unit` from integer precision value.""" + from math import ceil + + if not precision: + return None + elif isinstance(precision, str): + if precision.isdigit(): + precision = int(precision) + elif (precision := precision.lower()) in ("s", "ms", "us", "ns"): + return "ms" if precision == "s" else precision + try: + n = min(max(3, int(ceil(precision / 3)) * 3), 9) # type: ignore[operator] + return {3: "ms", 6: "us", 9: "ns"}.get(n) + except TypeError: + return None + + +def _infer_dtype_from_database_typename( + value: str, + *, + raise_unmatched: bool = True, +) -> PolarsDataType | None: + """Attempt to infer Polars dtype from database cursor `type_code` string value.""" + dtype: PolarsDataType | None = None + + # normalise string name/case (eg: 'IntegerType' -> 'INTEGER') + original_value = value + value = value.upper().replace("TYPE", "") + + # extract optional type modifier (eg: 'VARCHAR(64)' -> '64') + if re.search(r"\([\w,: ]+\)$", value): + modifier = value[value.find("(") + 1 : -1] + value = value.split("(")[0] + elif ( + not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value) + ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")): + modifier = value[value.find("[") + 1 : -1] + value = value.split("[")[0] + else: + modifier = "" + + # array dtypes + array_aliases = ("ARRAY", "LIST", "[]") + if value.endswith(array_aliases) or value.startswith(array_aliases): + for a in array_aliases: + value = value.replace(a, "", 1) if value else "" + + nested: PolarsDataType | None = None + if not value and modifier: + nested = _infer_dtype_from_database_typename( + value=modifier, + raise_unmatched=False, + ) + else: + if inner_value := _infer_dtype_from_database_typename( + value[1:-1] + if (value[0], value[-1]) == ("<", ">") + else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)), + raise_unmatched=False, + ): + nested = inner_value + elif modifier: + nested = _infer_dtype_from_database_typename( + value=modifier, + raise_unmatched=False, + ) + if nested: + dtype = List(nested) + + # float dtypes + elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"): + dtype = ( + Float32 + if value == "FLOAT4" + or (value.endswith(("16", "32")) or (modifier in ("16", "32"))) + else Float64 + ) + + # integer dtypes + elif ("INTERVAL" not in value) and ( + value.startswith(("INT", "UINT", "UNSIGNED")) + or value.endswith(("INT", "SERIAL")) + or ("INTEGER" in value) + or value == "ROWID" + ): + sz: Any + if "LARGE" in value or value.startswith("BIG") or value == "INT8": + sz = 64 + elif "MEDIUM" in value or value in ("INT4", "SERIAL"): + sz = 32 + elif "SMALL" in value or value == "INT2": + sz = 16 + elif "TINY" in value: + sz = 8 + else: + sz = None + + sz = modifier if (not sz and modifier) else sz + if not isinstance(sz, int): + sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None + if ( + ("U" in value and "MEDIUM" not in value) + or ("UNSIGNED" in value) + or value == "ROWID" + ): + dtype = _integer_dtype_from_nbits(sz, unsigned=True, default=UInt64) + else: + dtype = _integer_dtype_from_nbits(sz, unsigned=False, default=Int64) + + # decimal dtypes + elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value): + if "," in modifier: + prec, scale = modifier.split(",") + dtype = Decimal(int(prec), int(scale)) + else: + dtype = Decimal if is_dec else Float64 + + # string dtypes + elif ( + any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE")) + or value.startswith(("STR", "CHAR", "NCHAR", "UTF")) + or value.endswith(("_UTF8", "_UTF16", "_UTF32")) + ): + dtype = String + + # binary dtypes + elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"): + dtype = Binary + + # boolean dtypes + elif value.startswith("BOOL"): + dtype = Boolean + + # temporal dtypes + elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")): + if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")): + if "WITHOUT" not in value: + return None # there's a timezone, but we don't know what it is + unit = _timeunit_from_precision(modifier) if modifier else "us" + dtype = Datetime(time_unit=(unit or "us")) # type: ignore[arg-type] + + elif re.sub(r"\d", "", value) in ("INTERVAL", "TIMEDELTA"): + dtype = Duration + + elif value in ("DATE", "DATE32", "DATE64"): + dtype = Date + + elif value in ("TIME", "TIME32", "TIME64"): + dtype = Time + + if not dtype and raise_unmatched: + msg = f"cannot infer dtype from {original_value!r} string value" + raise ValueError(msg) + + return dtype + + +@functools.lru_cache(8) +def _integer_dtype_from_nbits( + bits: int, + *, + unsigned: bool, + default: PolarsDataType | None = None, +) -> PolarsDataType | None: + dtype = { + (8, False): Int8, + (8, True): UInt8, + (16, False): Int16, + (16, True): UInt16, + (32, False): Int32, + (32, True): UInt32, + (64, False): Int64, + (64, True): UInt64, + }.get((bits, unsigned), None) + + if dtype is None and default is not None: + return default + return dtype + + def is_polars_dtype(dtype: Any, *, include_unknown: bool = False) -> bool: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" try: @@ -415,10 +596,10 @@ def py_type_to_dtype( try: return _map_py_type_to_dtype(data_type) except (KeyError, TypeError): # pragma: no cover - if not raise_unmatched: - return None - msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})" - raise ValueError(msg) from None + if raise_unmatched: + msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})" + raise ValueError(msg) from None + return None def py_type_to_arrow_type(dtype: PythonDataType) -> pa.lib.DataType: diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index f7e933b1e448..abb432c88e3c 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -2,13 +2,26 @@ import re import sys +from contextlib import suppress from importlib import import_module -from inspect import Parameter, signature +from inspect import Parameter, isclass, signature from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence, TypedDict, overload from polars._utils.deprecation import issue_deprecation_warning from polars.convert import from_arrow -from polars.datatypes import N_INFER_DEFAULT +from polars.datatypes import ( + INTEGER_DTYPES, + N_INFER_DEFAULT, + UNSIGNED_INTEGER_DTYPES, + Decimal, + Float32, + Float64, +) +from polars.datatypes.convert import ( + _infer_dtype_from_database_typename, + _integer_dtype_from_nbits, + _map_py_type_to_dtype, +) from polars.exceptions import InvalidOperationError, UnsuitableSQLError if TYPE_CHECKING: @@ -26,6 +39,7 @@ from typing_extensions import Self from polars import DataFrame + from polars.datatypes import PolarsDataType from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict try: @@ -295,17 +309,19 @@ def _from_rows( if hasattr(self.result, "fetchall"): if self.driver_name == "sqlalchemy": if hasattr(self.result, "cursor"): - cursor_desc = {d[0]: d[1] for d in self.result.cursor.description} + cursor_desc = {d[0]: d[1:] for d in self.result.cursor.description} elif hasattr(self.result, "_metadata"): cursor_desc = {k: None for k in self.result._metadata.keys} else: msg = f"Unable to determine metadata from query result; {self.result!r}" raise ValueError(msg) else: - cursor_desc = {d[0]: d[1] for d in self.result.description} + cursor_desc = {d[0]: d[1:] for d in self.result.description} - # TODO: refine types based on the cursor description's type_code, - # if/where available? (for now, we just read the column names) + schema_overrides = self._inject_type_overrides( + description=cursor_desc, + schema_overrides=(schema_overrides or {}), + ) result_columns = list(cursor_desc) frames = ( DataFrame( @@ -324,17 +340,79 @@ def _from_rows( return frames if iter_batches else next(frames) # type: ignore[arg-type] return None - def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor: + def _inject_type_overrides( + self, + description: dict[str, Any], + schema_overrides: SchemaDict, + ) -> SchemaDict: + """Attempt basic dtype inference from a cursor description.""" + # note: this is limited; the `type_code` property may contain almost anything, + # from strings or python types to driver-specific codes, classes, enums, etc. + # currently we only do additional inference from string/python type values. + # (further refinement requires per-driver module knowledge and lookups). + + dtype: PolarsDataType | None = None + for nm, desc in description.items(): + if desc is None: + continue + elif nm not in schema_overrides: + type_code, _disp_size, internal_size, prec, scale, _null_ok = desc + if isclass(type_code): + # python types, eg: int, float, str, etc + with suppress(TypeError): + dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type] + + elif isinstance(type_code, str): + # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc + dtype = _infer_dtype_from_database_typename( + value=type_code, + raise_unmatched=False, + ) + + if dtype is not None: + # check additional cursor information to improve dtype inference + if dtype == Float64 and internal_size == 4: + dtype = Float32 + + elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8): + bits = internal_size * 8 + dtype = _integer_dtype_from_nbits( + bits, + unsigned=(dtype in UNSIGNED_INTEGER_DTYPES), + default=dtype, + ) + elif ( + dtype == Decimal + and isinstance(prec, int) + and isinstance(scale, int) + and prec <= 38 + and scale <= 38 + ): + dtype = Decimal(prec, scale) + + if dtype is not None: + schema_overrides[nm] = dtype # type: ignore[index] + + return schema_overrides + + def _normalise_cursor(self, conn: Any) -> Cursor: """Normalise a connection object such that we have the query executor.""" - if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine": - self.can_close_cursor = True - if conn.driver == "databricks-sql-python": # type: ignore[union-attr] - # take advantage of the raw connection to get arrow integration - self.driver_name = "databricks" - return conn.raw_connection().cursor() # type: ignore[union-attr, return-value] + if self.driver_name == "sqlalchemy": + self.can_close_cursor = (conn_type := type(conn).__name__) == "Engine" + if conn_type == "Session": + return conn else: - # sqlalchemy engine; direct use is deprecated, so prefer the connection - return conn.connect() # type: ignore[union-attr, return-value] + # where possible, use the raw connection to access arrow integration + if conn.engine.driver == "databricks-sql-python": + self.driver_name = "databricks" + return conn.engine.raw_connection().cursor() + elif conn.engine.driver == "duckdb_engine": + self.driver_name = "duckdb" + return conn.engine.raw_connection().driver_connection.c + elif conn_type == "Engine": + return conn.connect() + else: + return conn elif hasattr(conn, "cursor"): # connection has a dedicated cursor; prefer over direct execute @@ -344,7 +422,7 @@ def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor: elif hasattr(conn, "execute"): # can execute directly (given cursor, sqlalchemy connection, etc) - return conn # type: ignore[return-value] + return conn msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method" raise TypeError(msg) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index fde1b294d6f0..509ca2b4d2b6 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -16,11 +16,13 @@ from sqlalchemy.sql.expression import cast as alchemy_cast import polars as pl +from polars.datatypes.convert import _infer_dtype_from_database_typename from polars.exceptions import ComputeError, UnsuitableSQLError from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal if TYPE_CHECKING: + from polars.datatypes import PolarsDataType from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -806,3 +808,84 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8} ), ) + + +@pytest.mark.parametrize( + ("value", "expected_dtype"), + [ + # string types + ("UTF16", pl.String), + ("char(8)", pl.String), + ("nchar[128]", pl.String), + ("varchar", pl.String), + ("CHARACTER VARYING(64)", pl.String), + ("nvarchar(32)", pl.String), + ("TEXT", pl.String), + # array types + ("float32[]", pl.List(pl.Float32)), + ("double array", pl.List(pl.Float64)), + ("array[bool]", pl.List(pl.Boolean)), + ("array of nchar(8)", pl.List(pl.String)), + ("array[array[int8]]", pl.List(pl.List(pl.Int64))), + # numeric types + ("numeric[10,5]", pl.Decimal(10, 5)), + ("bigdecimal", pl.Decimal), + ("decimal128(10,5)", pl.Decimal(10, 5)), + ("double precision", pl.Float64), + ("floating point", pl.Float64), + ("numeric", pl.Float64), + ("real", pl.Float64), + ("boolean", pl.Boolean), + ("tinyint", pl.Int8), + ("smallint", pl.Int16), + ("int", pl.Int64), + ("int4", pl.Int32), + ("int2", pl.Int16), + ("int(16)", pl.Int16), + ("ROWID", pl.UInt64), + ("mediumint", pl.Int32), + ("unsigned mediumint", pl.UInt32), + ("smallserial", pl.Int16), + ("serial", pl.Int32), + ("bigserial", pl.Int64), + # temporal types + ("timestamp(3)", pl.Datetime("ms")), + ("timestamp(5)", pl.Datetime("us")), + ("timestamp(7)", pl.Datetime("ns")), + ("datetime without tz", pl.Datetime("us")), + ("date", pl.Date), + ("time", pl.Time), + ("date32", pl.Date), + ("time64", pl.Time), + # binary types + ("BYTEA", pl.Binary), + ("BLOB", pl.Binary), + ], +) +def test_database_dtype_inference_from_string( + value: str, + expected_dtype: PolarsDataType, +) -> None: + inferred_dtype = _infer_dtype_from_database_typename(value) + assert inferred_dtype == expected_dtype # type: ignore[operator] + + +@pytest.mark.parametrize( + "value", + [ + "FooType", + "Unknown", + "MISSING", + "XML", # note: we deliberately exclude "number" as it is ambiguous. + "Number", # (could refer to any size of int, float, or decimal dtype) + ], +) +def test_database_dtype_inference_from_invalid_string(value: str) -> None: + with pytest.raises(ValueError, match="cannot infer dtype"): + _infer_dtype_from_database_typename(value) + + inferred_dtype = _infer_dtype_from_database_typename( + value=value, + raise_unmatched=False, + ) + assert inferred_dtype is None