From e5a704deb0556f3e81383078e5b2809a075d8a9d Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 10 Dec 2024 02:37:29 +0530
Subject: [PATCH 01/12] added the initial skeleton for the polars database dataset

Signed-off-by: Minura Punchihewa
---
 .../polars/__init__.py                        | 11 +++
 .../polars/polars_database_dataset.py         | 77 +++++++++++++++++++
 2 files changed, 88 insertions(+)
 create mode 100644 kedro-datasets/kedro_datasets_experimental/polars/__init__.py
 create mode 100644 kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/__init__.py b/kedro-datasets/kedro_datasets_experimental/polars/__init__.py
new file mode 100644
index 000000000..49f696047
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/polars/__init__.py
@@ -0,0 +1,11 @@
+"""``AbstractDataset`` implementation to load/save to databases using the Polars library."""
+
+from typing import Any
+
+import lazy_loader as lazy
+
+PolarsDatabaseDataset: Any
+
+__getattr__, __dir__, __all__ = lazy.attach(
+    __name__, submod_attrs={"polars_database_dataset": ["PolarsDatabaseDataset"]}
+)
diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
new file mode 100644
index 000000000..db2eb041d
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -0,0 +1,77 @@
+import copy
+from typing import Any, NoReturn
+
+import fsspec
+import polars as pl
+from kedro.io.core import (
+    AbstractDataset,
+    DatasetError,
+    get_protocol_and_path,
+)
+
+
+class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
+    def __init__(  # noqa: PLR0913
+        self,
+        sql: str | None = None,
+        credentials: dict[str, Any] | None = None,
+        load_args: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        filepath: str | None = None,
+        execution_options: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        if sql and filepath:
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be provided. "
+                "Please only provide one."
+            )
+
+        if not (sql or filepath):
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be empty. "
+                "Please provide a sql query or path to a sql query file."
+            )
+
+        if not (credentials and "con" in credentials and credentials["con"]):
+            raise DatasetError(
+                "'con' argument cannot be empty. Please "
+                "provide a SQLAlchemy connection string."
+            )
+
+        default_load_args: dict[str, Any] = {}
+
+        self._load_args = (
+            {**default_load_args, **load_args}
+            if load_args is not None
+            else default_load_args
+        )
+
+        self.metadata = metadata
+
+        # load sql query from file
+        if sql:
+            self._load_args["sql"] = sql
+            self._filepath = None
+        else:
+            # filesystem for loading sql file
+            _fs_args = copy.deepcopy(fs_args) or {}
+            _fs_credentials = _fs_args.pop("credentials", {})
+            protocol, path = get_protocol_and_path(str(filepath))
+
+            self._protocol = protocol
+            self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
+            self._filepath = path
+        self._connection_str = credentials["con"]
+        self._connection_args = {
+            k: credentials[k] for k in credentials.keys() if k != "con"
+        }
+        self._execution_options = execution_options or {}
+        if "mssql" in self._connection_str:
+            self.adapt_mssql_date_params()
+
+    def load(self) -> pl.DataFrame:
+        pass
+
+    def save(self, data: None) -> NoReturn:
+        pass
\ No newline at end of file

From 6fc01a89c44f736ffaa55258719600ec9c5fafb0 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 14 Jan 2025 23:29:30 +0530
Subject: [PATCH 02/12] updated the implementation by extending SQLQueryDataset

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py         | 79 ++++++-------------
 1 file changed, 25 insertions(+), 54 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index db2eb041d..2ab287265 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -1,16 +1,14 @@
 import copy
+from pathlib import PurePosixPath
 from typing import Any, NoReturn
 
-import fsspec
 import polars as pl
-from kedro.io.core import (
-    AbstractDataset,
-    DatasetError,
-    get_protocol_and_path,
-)
+
+from kedro_datasets.pandas.sql_dataset import SQLQueryDataset, get_filepath_str
+
+
+class PolarsDatabaseDataset(SQLQueryDataset):
 
-class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
     def __init__(  # noqa: PLR0913
         self,
         sql: str | None = None,
@@ -18,60 +16,33 @@ def __init__(  # noqa: PLR0913
         load_args: dict[str, Any] | None = None,
         fs_args: dict[str, Any] | None = None,
         filepath: str | None = None,
-        execution_options: dict[str, Any] | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
-        if sql and filepath:
-            raise DatasetError(
-                "'sql' and 'filepath' arguments cannot both be provided. "
-                "Please only provide one."
-            )
-
-        if not (sql or filepath):
-            raise DatasetError(
-                "'sql' and 'filepath' arguments cannot both be empty. "
-                "Please provide a sql query or path to a sql query file."
-            )
-
-        if not (credentials and "con" in credentials and credentials["con"]):
-            raise DatasetError(
-                "'con' argument cannot be empty. Please "
-                "provide a SQLAlchemy connection string."
-            )
-
-        default_load_args: dict[str, Any] = {}
-
-        self._load_args = (
-            {**default_load_args, **load_args}
-            if load_args is not None
-            else default_load_args
+        """Creates a new ``PolarsDatabaseDataset``."""
+        super().__init__(
+            sql=sql,
+            credentials=credentials,
+            load_args=load_args,
+            fs_args=fs_args,
+            filepath=filepath,
+            metadata=metadata,
         )
 
-        self.metadata = metadata
+    def load(self) -> pl.DataFrame:
+        load_args = copy.deepcopy(self._load_args)
 
-        # load sql query from file
-        if sql:
-            self._load_args["sql"] = sql
-            self._filepath = None
+        if self._filepath:
+            load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
+            with self._fs.open(load_path, mode="r") as fs_file:
+                query = fs_file.read()
         else:
-            # filesystem for loading sql file
-            _fs_args = copy.deepcopy(fs_args) or {}
-            _fs_credentials = _fs_args.pop("credentials", {})
-            protocol, path = get_protocol_and_path(str(filepath))
-
-            self._protocol = protocol
-            self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
-            self._filepath = path
-        self._connection_str = credentials["con"]
-        self._connection_args = {
-            k: credentials[k] for k in credentials.keys() if k != "con"
-        }
-        self._execution_options = execution_options or {}
-        if "mssql" in self._connection_str:
-            self.adapt_mssql_date_params()
+            query = load_args.pop("sql")
 
-    def load(self) -> pl.DataFrame:
-        pass
+        return pl.read_database(
+            query=query,
+            connection=self._connection_str,
+            **load_args
+        )
 
     def save(self, data: None) -> NoReturn:
         pass
\ No newline at end of file

From 87766d5512fbda87cfddbf6cdd8a4c36a43ac2c4 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 28 Jan 2025 09:17:45 +0530
Subject: [PATCH 03/12] removed dependency on SQLQueryDataset

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py         | 196 +++++++++++++++++-
 1 file changed, 186 insertions(+), 10 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 2ab287265..98a59709e 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -1,13 +1,93 @@
 import copy
 from pathlib import PurePosixPath
+import re
 from typing import Any, NoReturn
 
+import fsspec
+from kedro.io.core import (
+    AbstractDataset,
+    DatasetError,
+    get_filepath_str,
+    get_protocol_and_path,
+)
 import polars as pl
+from sqlalchemy import MetaData, Table, create_engine, inspect, select
+from sqlalchemy.exc import NoSuchModuleError
 
-from kedro_datasets.pandas.sql_dataset import SQLQueryDataset, get_filepath_str
+from kedro_datasets._typing import TablePreview
+
+
+KNOWN_PIP_INSTALL = {
+    "psycopg2": "psycopg2",
+    "mysqldb": "mysqlclient",
+    "cx_Oracle": "cx_Oracle",
+    "mssql": "pyodbc",
+}
+
+DRIVER_ERROR_MESSAGE = """
+A module/driver is missing when connecting to your SQL server. PolarsDatabaseDataset
+ supports SQLAlchemy drivers. Please refer to
+ https://docs.sqlalchemy.org/core/engines.html#supported-databases
+ for more information.
+\n\n
+"""
+
+
+def _find_known_drivers(module_import_error: ImportError) -> str | None:
+    """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
+    provide better guidance for the user.
+
+    Args:
+        module_import_error: Error raised while connecting to a SQL server.
+
+    Returns:
+        Instructions for installing missing driver. ``None`` is
+        returned if the error is related to an unknown driver.
+
+    """
+
+    # module errors contain string "No module name 'module_name'"
+    # we are trying to extract module_name surrounded by quotes here
+    res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
+
+    # in case module import error does not match our expected pattern
+    # we have no recommendation
+    if not res:
+        return None
+
+    missing_module = res[0]
+
+    if KNOWN_PIP_INSTALL.get(missing_module):
+        return (
+            f"You can also try installing missing driver with\n"
+            f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
+        )
+
+    return None
+
+
+def _get_missing_module_error(import_error: ImportError) -> DatasetError:
+    missing_module_instruction = _find_known_drivers(import_error)
+
+    if missing_module_instruction is None:
+        return DatasetError(
+            f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
+        )
+
+    return DatasetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}")
+
+
+def _get_sql_alchemy_missing_error() -> DatasetError:
+    return DatasetError(
+        "The SQL dialect in your connection is not supported by "
+        "SQLAlchemy. Please refer to "
+        "https://docs.sqlalchemy.org/core/engines.html#supported-databases "
+        "for more information."
+    )
 
 
-class PolarsDatabaseDataset(SQLQueryDataset):
+class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
 
     def __init__(  # noqa: PLR0913
         self,
@@ -14,31 +94,98 @@ def __init__(  # noqa: PLR0913
         sql: str | None = None,
         credentials: dict[str, Any] | None = None,
         load_args: dict[str, Any] | None = None,
         fs_args: dict[str, Any] | None = None,
         filepath: str | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
         """Creates a new ``PolarsDatabaseDataset``."""
-        super().__init__(
-            sql=sql,
-            credentials=credentials,
-            load_args=load_args,
-            fs_args=fs_args,
-            filepath=filepath,
-            metadata=metadata,
+        if sql and filepath:
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be provided. "
+                "Please only provide one."
+            )
+
+        if not (sql or filepath):
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be empty. "
+                "Please provide a sql query or path to a sql query file."
+            )
+
+        if not (credentials and "con" in credentials and credentials["con"]):
+            raise DatasetError(
+                "'con' argument cannot be empty. Please "
+                "provide a SQLAlchemy connection string."
+            )
+
+        default_load_args: dict[str, Any] = {}
+
+        self._load_args = (
+            {**default_load_args, **load_args}
+            if load_args is not None
+            else default_load_args
         )
 
+        self.metadata = metadata
+
+        # load sql query from file
+        if sql:
+            self._load_args["sql"] = sql
+            self._filepath = None
+        else:
+            # filesystem for loading sql file
+            _fs_args = copy.deepcopy(fs_args) or {}
+            _fs_credentials = _fs_args.pop("credentials", {})
+            protocol, path = get_protocol_and_path(str(filepath))
+
+            self._protocol = protocol
+            self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
+            self._filepath = path
+        self._connection_str = credentials["con"]
+        self._connection_args = {
+            k: credentials[k] for k in credentials.keys() if k != "con"
+        }
+        if "mssql" in self._connection_str:
+            self.adapt_mssql_date_params()
+
+    @classmethod
+    def create_connection(
+        cls, connection_str: str, connection_args: dict | None = None
+    ) -> None:
+        """Given a connection string, create a singleton connection
+        to be used across all instances of `PolarsDatabaseDataset` that
+        need to connect to the same source.
+        """
+        connection_args = connection_args or {}
+        try:
+            engine = create_engine(connection_str, **connection_args)
+        except ImportError as import_error:
+            raise _get_missing_module_error(import_error) from import_error
+        except NoSuchModuleError as exc:
+            raise _get_sql_alchemy_missing_error() from exc
+
+        cls.engines[connection_str] = engine
+
+    @property
+    def engine(self):
+        """The ``Engine`` object for the dataset's connection string."""
+        cls = type(self)
+
+        if self._connection_str not in cls.engines:
+            self.create_connection(self._connection_str, self._connection_args)
+
+        return cls.engines[self._connection_str]
+
     def load(self) -> pl.DataFrame:
         load_args = copy.deepcopy(self._load_args)
 
@@ -45,4 +192,33 @@ def load(self) -> pl.DataFrame:
         )
 
     def save(self, data: None) -> NoReturn:
-        pass
\ No newline at end of file
+        pass
+
+    def _exists(self) -> bool:
+        insp = inspect(self.engine)
+        schema = self._load_args.get("schema", None)
+        return insp.has_table(self._load_args["table_name"], schema)
+
+    def preview(self, nrows: int = 5) -> TablePreview:
+        """
+        Generate a preview of the dataset with a specified number of rows.
+
+        Args:
+            nrows: The number of rows to include in the preview. Defaults to 5.
+
+        Returns:
+            dict: A dictionary containing the data in a split format.
+        """
+        table_name = self._load_args["table_name"]
+
+        metadata = MetaData()
+        table_ref = Table(table_name, metadata, autoload_with=self.engine)
+
+        query = select(table_ref).limit(nrows)  # type: ignore[arg-type]
+
+        with self.engine.connect() as conn:
+            result = conn.execute(query)
+            data_preview = pl.DataFrame(result.fetchall(), columns=result.keys())
+
+        preview_data = data_preview.to_dict(orient="split")
+        return preview_data
\ No newline at end of file

From 0520de6cf58b84f9c42a66e3de77e5f40bd994c7 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 28 Jan 2025 09:21:34 +0530
Subject: [PATCH 04/12] added the missing func to adapt mssql params

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py         | 25 ++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 98a59709e..eb436de36 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -221,4 +221,27 @@ def preview(self, nrows: int = 5) -> TablePreview:
             data_preview = pl.DataFrame(result.fetchall(), columns=result.keys())
 
         preview_data = data_preview.to_dict(orient="split")
-        return preview_data
\ No newline at end of file
+        return preview_data
+
+    # For mssql only
+    def adapt_mssql_date_params(self) -> None:
+        """We need to change the format of datetime parameters.
+        MSSQL expects datetime in the exact format %y-%m-%dT%H:%M:%S.
+        Here, we also accept plain dates.
+        `pyodbc` does not accept named parameters, they must be provided as a list."""
+        params = self._load_args.get("params", [])
+        if not isinstance(params, list):
+            raise DatasetError(
+                "Unrecognized `params` format. It can only be a `list`, "
+                f"got {type(params)!r}"
+            )
+        new_load_args = []
+        for value in params:
+            try:
+                as_date = dt.date.fromisoformat(value)
+                new_val = dt.datetime.combine(as_date, dt.time.min)
+                new_load_args.append(new_val.strftime("%Y-%m-%dT%H:%M:%S"))
+            except (TypeError, ValueError):
+                new_load_args.append(value)
+        if new_load_args:
+            self._load_args["params"] = tuple(new_load_args)
\ No newline at end of file

From 4a57c15c45fb995115ab929a0a2f661bc6ab5b79 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 28 Jan 2025 12:05:32 +0530
Subject: [PATCH 05/12] implemented save()

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py         | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index eb436de36..5d3ffe326 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -96,6 +96,7 @@ def __init__(  # noqa: PLR0913
         load_args: dict[str, Any] | None = None,
         fs_args: dict[str, Any] | None = None,
         filepath: str | None = None,
+        table_name: str | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
         """Creates a new ``PolarsDatabaseDataset``."""
@@ -125,6 +126,7 @@ def __init__(  # noqa: PLR0913
             else default_load_args
         )
 
+        self.table_name = table_name
         self.metadata = metadata
 
         # load sql query from file
@@ -191,8 +193,16 @@ def load(self) -> pl.DataFrame:
             **load_args
         )
 
-    def save(self, data: None) -> NoReturn:
-        pass
+    def save(self, data: pl.DataFrame) -> NoReturn:
+        if not self.table_name:
+            raise DatasetError(
+                "'table_name' argument is required to save datasets."
+            )
+
+        data.write_database(
+            table=self.table_name,
+            connection=self._connection_str,
+        )
 
     def _exists(self) -> bool:
         insp = inspect(self.engine)

From 5ec37cd9a78ff80fe18b6c9a4a1dbe96b782e04c Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 28 Jan 2025 12:11:39 +0530
Subject: [PATCH 06/12] added missing import

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 5d3ffe326..93300ff01 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -1,4 +1,5 @@
 import copy
+import datetime as dt
 from pathlib import PurePosixPath
 import re
 from typing import Any, NoReturn

From c410cf302c14aaaf6e007cf11ab2f137c8042c50 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 28 Jan 2025 12:20:25 +0530
Subject: [PATCH 07/12] introduced the save_args param

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 93300ff01..3323dac49 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -1,3 +1,5 @@
+"""``PolarsDatabaseDataset`` to load and save data to a SQL backend using Polars."""
+
 import copy
 import datetime as dt
 from pathlib import PurePosixPath
@@ -92,12 +94,14 @@ class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
 
     def __init__(  # noqa: PLR0913
         self,
+        *,
         sql: str | None = None,
         credentials: dict[str, Any] | None = None,
         load_args: dict[str, Any] | None = None,
         fs_args: dict[str, Any] | None = None,
         filepath: str | None = None,
         table_name: str | None = None,
+        save_args: dict[str, Any] | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
         """Creates a new ``PolarsDatabaseDataset``."""
@@ -120,6 +124,7 @@ def __init__(  # noqa: PLR0913
         )
 
         default_load_args: dict[str, Any] = {}
+        default_save_args: dict[str, Any] = {}
 
         self._load_args = (
             {**default_load_args, **load_args}
@@ -128,6 +133,12 @@ def __init__(  # noqa: PLR0913
         )
 
         self.table_name = table_name
+        self._save_args = (
+            {**default_save_args, **save_args}
+            if save_args is not None
+            else default_save_args
+        )
+
         self.metadata = metadata
 
         # load sql query from file
@@ -209,6 +220,7 @@ def save(self, data: pl.DataFrame) -> NoReturn:
         data.write_database(
             table=self.table_name,
             connection=self._connection_str,
+            **self._save_args
         )

From b1df871247cf2fc26bdc4631cc523313cb50f756 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Tue, 28 Jan 2025 13:36:23 +0530
Subject: [PATCH 08/12] implemented _describe()

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py         | 39 +++++--------------
 1 file changed, 10 insertions(+), 29 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 3323dac49..a92bd77e2 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -188,6 +188,16 @@ def engine(self):
             self.create_connection(self._connection_str, self._connection_args)
 
         return cls.engines[self._connection_str]
+    
+    def _describe(self) -> dict[str, Any]:
+        load_args = copy.deepcopy(self._load_args)
+        return {
+            "sql": str(load_args.pop("sql", None)),
+            "filepath": str(self._filepath),
+            "load_args": str(load_args),
+            "table_name": self.table_name,
+            "save_args": str(self._save_args),
+        }
 
     def load(self) -> pl.DataFrame:
         load_args = copy.deepcopy(self._load_args)
@@ -216,35 +226,6 @@ def save(self, data: pl.DataFrame) -> NoReturn:
             connection=self._connection_str,
             **self._save_args
         )
-
-    def _exists(self) -> bool:
-        insp = inspect(self.engine)
-        schema = self._load_args.get("schema", None)
-        return insp.has_table(self._load_args["table_name"], schema)
-
-    def preview(self, nrows: int = 5) -> TablePreview:
-        """
-        Generate a preview of the dataset with a specified number of rows.
-
-        Args:
-            nrows: The number of rows to include in the preview. Defaults to 5.
-
-        Returns:
-            dict: A dictionary containing the data in a split format.
-        """
-        table_name = self._load_args["table_name"]
-
-        metadata = MetaData()
-        table_ref = Table(table_name, metadata, autoload_with=self.engine)
-
-        query = select(table_ref).limit(nrows)  # type: ignore[arg-type]
-
-        with self.engine.connect() as conn:
-            result = conn.execute(query)
-            data_preview = pl.DataFrame(result.fetchall(), columns=result.keys())
-
-        preview_data = data_preview.to_dict(orient="split")
-        return preview_data
 
     # For mssql only
     def adapt_mssql_date_params(self) -> None:

From 54afac38f5fb729b0dd66458c527eb63ef48bdf8 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Sat, 1 Feb 2025 13:02:52 +0530
Subject: [PATCH 09/12] updated the docstring for the database

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py | 56 +++++++++++++++++++++++---
 1 file changed, 51 insertions(+), 5 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index a92bd77e2..0a0e2443e 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -14,12 +14,9 @@
     get_protocol_and_path,
 )
 import polars as pl
-from sqlalchemy import MetaData, Table, create_engine, inspect, select
+from sqlalchemy import create_engine 
 from sqlalchemy.exc import NoSuchModuleError
 
-from kedro_datasets._typing import TablePreview
-
 
 KNOWN_PIP_INSTALL = {
@@ -91,7 +88,56 @@ def _get_sql_alchemy_missing_error() -> DatasetError:
 
 
 class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
-
+    """``PolarsDatabaseDataset`` loads data from a provided SQL query or writes data to a table.
+    It supports all allowed polars options on ``read_database`` and ``write_database``. 
+    Since ``PolarsDatabaseDataset`` uses SQLAlchemy behind the scenes, when instantiating it one needs to pass
+    a compatible connection string either in ``credentials`` (see the example
+    code snippet below) or in ``load_args``. Connection string formats supported
+    by SQLAlchemy can be found here:
+    https://docs.sqlalchemy.org/core/engines.html#database-urls
+
+    Example usage for the
+    `YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:
+
+    .. code-block:: yaml
+
+        shuttle_id_dataset:
+          type: polars.PolarsDatabaseDataset
+          sql: "select shuttle, shuttle_id from spaceflights.shuttles;"
+          credentials: db_credentials
+
+    Sample database credentials entry in ``credentials.yml``:
+
+    .. code-block:: yaml
+
+        db_credentials:
+          con: postgresql://scott:tiger@localhost/test
+          pool_size: 10 # additional parameters
+
+    Example usage for the
+    `Python API <https://docs.kedro.org/en/stable/data/advanced_data_catalog_usage.html>`_:
+
+    .. code-block:: pycon
+
+        >>> from pathlib import Path
+        >>> import polars as pl
+        >>>
+        >>> from kedro_datasets_experimental.polars import PolarsDatabaseDataset
+        >>>
+        >>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
+        >>> sql = "SELECT * FROM table_a"
+        >>> tmp_path = Path.cwd() / "tmp"
+        >>> credentials = {"con": f"sqlite:///{tmp_path / 'test.db'}"}
+        >>> dataset = PolarsDatabaseDataset(sql=sql, credentials=credentials, table_name="table_a")
+        >>>
+        >>> dataset.save(data)
+        >>> reloaded = dataset.load()
+        >>>
+        >>> assert data.equals(reloaded)
+    """
     def __init__(  # noqa: PLR0913
         self,
         *,

From e6910fe7aaec0052a3a3963233d755c7accb7fac Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Sat, 1 Feb 2025 13:39:56 +0530
Subject: [PATCH 10/12] fixed save()

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 0a0e2443e..3e862ebbb 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -130,6 +130,7 @@ class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
         >>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
         >>> sql = "SELECT * FROM table_a"
         >>> tmp_path = Path.cwd() / "tmp"
+        >>> tmp_path.mkdir(parents=True, exist_ok=True)
         >>> credentials = {"con": f"sqlite:///{tmp_path / 'test.db'}"}
         >>> dataset = PolarsDatabaseDataset(sql=sql, credentials=credentials, table_name="table_a")
         >>>
@@ -138,6 +139,10 @@ class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
         >>>
         >>> assert data.equals(reloaded)
     """
+    # using Any because of Sphinx but it should be
+    # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
+    engines: dict[str, Any] = {}
+
     def __init__(  # noqa: PLR0913
         self,
         *,
@@ -170,7 +175,9 @@ def __init__(  # noqa: PLR0913
         )
 
         default_load_args: dict[str, Any] = {}
-        default_save_args: dict[str, Any] = {}
+        default_save_args: dict[str, Any] = {
+            "if_table_exists": "replace"
+        }
 
         self._load_args = (
             {**default_load_args, **load_args}
@@ -268,7 +275,7 @@ def save(self, data: pl.DataFrame) -> NoReturn:
             )
 
         data.write_database(
-            table=self.table_name,
+            table_name=self.table_name,
             connection=self._connection_str,
             **self._save_args
         )

From b477fceaf079a871220fc7b54c7d98522c177809 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Sat, 1 Feb 2025 21:54:11 +0530
Subject: [PATCH 11/12] updated the required params

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 3e862ebbb..50119703f 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -162,9 +162,9 @@ def __init__(  # noqa: PLR0913
                 "Please only provide one."
             )
 
-        if not (sql or filepath):
+        if not (table_name or sql or filepath):
             raise DatasetError(
-                "'sql' and 'filepath' arguments cannot both be empty. "
+                "'table_name', 'sql' and 'filepath' arguments cannot all be empty. "
                 "Please provide a sql query or path to a sql query file."
             )

From 9423fb8d8fc0260f89327044211e5c0d9fb9b701 Mon Sep 17 00:00:00 2001
From: Minura Punchihewa
Date: Sat, 1 Feb 2025 22:06:08 +0530
Subject: [PATCH 12/12] fixed lint issues

Signed-off-by: Minura Punchihewa
---
 .../polars/polars_database_dataset.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
index 50119703f..4ae3de8a4 100644
--- a/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/polars/polars_database_dataset.py
@@ -2,22 +2,21 @@
 
 import copy
 import datetime as dt
-from pathlib import PurePosixPath
 import re
+from pathlib import PurePosixPath
 from typing import Any, NoReturn
 
 import fsspec
+import polars as pl
 from kedro.io.core import (
     AbstractDataset,
     DatasetError,
     get_filepath_str,
     get_protocol_and_path,
 )
-import polars as pl
-from sqlalchemy import create_engine 
+from sqlalchemy import create_engine
 from sqlalchemy.exc import NoSuchModuleError
 
-
 KNOWN_PIP_INSTALL = {
@@ -88,7 +87,7 @@ def _get_sql_alchemy_missing_error() -> DatasetError:
 
 class PolarsDatabaseDataset(AbstractDataset[None, pl.DataFrame]):
     """``PolarsDatabaseDataset`` loads data from a provided SQL query or writes data to a table.
-    It supports all allowed polars options on ``read_database`` and ``write_database``. 
+    It supports all allowed polars options on ``read_database`` and ``write_database``.
     Since ``PolarsDatabaseDataset`` uses SQLAlchemy behind the scenes, when instantiating it one needs to pass
     a compatible connection string either in ``credentials`` (see the example
     code snippet below) or in ``load_args``. Connection string formats supported
@@ -241,7 +240,7 @@ def engine(self):
             self.create_connection(self._connection_str, self._connection_args)
 
         return cls.engines[self._connection_str]
-    
+
     def _describe(self) -> dict[str, Any]:
         load_args = copy.deepcopy(self._load_args)
@@ -273,13 +272,13 @@ def save(self, data: pl.DataFrame) -> NoReturn:
             raise DatasetError(
                 "'table_name' argument is required to save datasets."
             )
-        
+
         data.write_database(
             table_name=self.table_name,
             connection=self._connection_str,
             **self._save_args
         )
-    
+
     # For mssql only
     def adapt_mssql_date_params(self) -> None:
         """We need to change the format of datetime parameters.
@@ -301,4 +300,4 @@ def adapt_mssql_date_params(self) -> None:
             except (TypeError, ValueError):
                 new_load_args.append(value)
         if new_load_args:
-            self._load_args["params"] = tuple(new_load_args)
\ No newline at end of file
+            self._load_args["params"] = tuple(new_load_args)
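
A minimal end-to-end sketch of the dataset as it stands after [PATCH 12/12],
mirroring the round trip in the class docstring. The SQLite connection string
and table name here are illustrative assumptions, not values fixed by the
patches; any SQLAlchemy-compatible URI should behave the same way:

    import polars as pl

    from kedro_datasets_experimental.polars import PolarsDatabaseDataset

    # 'table_name' is the target of save(); 'sql' is the query run by load().
    # All constructor arguments are keyword-only as of [PATCH 07/12].
    dataset = PolarsDatabaseDataset(
        sql="SELECT * FROM shuttles",
        table_name="shuttles",
        credentials={"con": "sqlite:///example.db"},  # assumed URI
    )

    data = pl.DataFrame({"shuttle": ["Falcon", "Dragon"], "shuttle_id": [1, 2]})

    dataset.save(data)         # writes via polars write_database
    reloaded = dataset.load()  # reads back via polars read_database
    assert data.equals(reloaded)

Passing both 'sql' and 'table_name' is the intended round-trip setup: the
validation added in [PATCH 11/12] only requires at least one of 'table_name',
'sql' or 'filepath', while save() itself insists on 'table_name'.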