Skip to content

Commit

Permalink
fix(python): Check for duplicate column names in read_database curs…
Browse files Browse the repository at this point in the history
…or result, raising `DuplicateError` if found (#18548)
  • Loading branch information
alexander-beedie authored Sep 5, 2024
1 parent aad45cd commit 242d30a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
33 changes: 20 additions & 13 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from polars import functions as F
from polars._utils.various import parse_version
from polars.convert import from_arrow
from polars.datatypes import (
N_INFER_DEFAULT,
from polars.datatypes import N_INFER_DEFAULT
from polars.exceptions import (
DuplicateError,
ModuleUpgradeRequiredError,
UnsuitableSQLError,
)
from polars.exceptions import ModuleUpgradeRequiredError, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy
from polars.io.database._inference import _infer_dtype_from_cursor_description
Expand Down Expand Up @@ -266,25 +268,25 @@ def _from_rows(
if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
if hasattr(self.result, "cursor"):
cursor_desc = {
d[0]: d[1:] for d in self.result.cursor.description
}
cursor_desc = [
(d[0], d[1:]) for d in self.result.cursor.description
]
elif hasattr(self.result, "_metadata"):
cursor_desc = {k: None for k in self.result._metadata.keys}
cursor_desc = [(k, None) for k in self.result._metadata.keys]
else:
msg = f"Unable to determine metadata from query result; {self.result!r}"
raise ValueError(msg)

elif hasattr(self.result, "description"):
cursor_desc = {d[0]: d[1:] for d in self.result.description}
cursor_desc = [(d[0], d[1:]) for d in self.result.description]
else:
cursor_desc = {}
cursor_desc = []

schema_overrides = self._inject_type_overrides(
description=cursor_desc,
schema_overrides=(schema_overrides or {}),
)
result_columns = list(cursor_desc)
result_columns = [nm for nm, _ in cursor_desc]
frames = (
DataFrame(
data=rows,
Expand All @@ -307,7 +309,7 @@ def _from_rows(

def _inject_type_overrides(
self,
description: dict[str, Any],
description: list[tuple[str, Any]],
schema_overrides: SchemaDict,
) -> SchemaDict:
"""
Expand All @@ -320,11 +322,16 @@ def _inject_type_overrides(
We currently only do the additional inference from string/python type values.
(Further refinement will require per-driver module knowledge and lookups).
"""
for nm, desc in description.items():
if desc is not None and nm not in schema_overrides:
dupe_check = set()
for nm, desc in description:
if nm in dupe_check:
msg = f"column {nm!r} appears more than once in the query/result cursor"
raise DuplicateError(msg)
elif desc is not None and nm not in schema_overrides:
dtype = _infer_dtype_from_cursor_description(self.cursor, desc)
if dtype is not None:
schema_overrides[nm] = dtype # type: ignore[index]
dupe_check.add(nm)

return schema_overrides

Expand Down
19 changes: 18 additions & 1 deletion py-polars/tests/unit/io/database/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import polars as pl
from polars._utils.various import parse_version
from polars.exceptions import UnsuitableSQLError
from polars.exceptions import DuplicateError, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.testing import assert_frame_equal

Expand Down Expand Up @@ -678,6 +678,23 @@ def test_read_database_exceptions(
read_database(**params)


@pytest.mark.parametrize(
"query",
[
"SELECT 1, 1 FROM test_data",
'SELECT 1 AS "n", 2 AS "n" FROM test_data',
'SELECT name, value AS "name" FROM test_data',
],
)
def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None:
alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect()
with pytest.raises(
DuplicateError,
match="column .+ appears more than once in the query/result cursor",
):
pl.read_database(query, connection=alchemy_conn)


@pytest.mark.parametrize(
"uri",
[
Expand Down

0 comments on commit 242d30a

Please sign in to comment.