test(datasets): make SQL dataset examples runnable #455

Merged (8 commits) on Dec 7, 2023
1 change: 0 additions & 1 deletion Makefile
@@ -39,7 +39,6 @@ dataset-doctest%:
--ignore kedro_datasets/databricks/managed_table_dataset.py \
--ignore kedro_datasets/pandas/deltatable_dataset.py \
--ignore kedro_datasets/pandas/gbq_dataset.py \
- --ignore kedro_datasets/pandas/sql_dataset.py \
--ignore kedro_datasets/partitions/incremental_dataset.py \
--ignore kedro_datasets/partitions/partitioned_dataset.py \
--ignore kedro_datasets/polars/lazy_polars_dataset.py \
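
With kedro_datasets/pandas/sql_dataset.py dropped from the ignore list, its docstring examples now execute as part of the dataset doctest target. A minimal sketch of running the same check on one module, assuming the Make target wraps pytest's --doctest-modules runner and that a conftest injects pytest's tmp_path fixture into the doctest namespace (both assumptions; the exact wiring is not part of this diff):

# Hypothetical direct invocation of the doctest runner for a single module.
import pytest

exit_code = pytest.main(["--doctest-modules", "kedro_datasets/pandas/sql_dataset.py"])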
38 changes: 24 additions & 14 deletions kedro-datasets/kedro_datasets/pandas/sql_dataset.py
@@ -135,11 +135,11 @@ class SQLTableDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>> table_name = "table_a"
- >>> credentials = {"con": "postgresql://scott:tiger@localhost/test"}
- >>> data_set = SQLTableDataset(table_name=table_name, credentials=credentials)
+ >>> credentials = {"con": f"sqlite:///{tmp_path / 'test.db'}"}
+ >>> dataset = SQLTableDataset(table_name=table_name, credentials=credentials)
>>>
- >>> data_set.save(data)
- >>> reloaded = data_set.load()
+ >>> dataset.save(data)
+ >>> reloaded = dataset.load()
>>>
>>> assert data.equals(reloaded)
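
The substance of this hunk is swapping a PostgreSQL URL, which needs a live server and credentials, for a SQLite file under pytest's tmp_path fixture, so the doctest can run anywhere. A standalone sketch of the same round trip outside pytest, with tempfile standing in for the fixture (assumes pandas, SQLAlchemy and kedro-datasets are installed):

import tempfile
from pathlib import Path

import pandas as pd
from kedro_datasets.pandas import SQLTableDataset

with tempfile.TemporaryDirectory() as tmp:
    # SQLite creates test.db on first write; no database server is required.
    credentials = {"con": f"sqlite:///{Path(tmp) / 'test.db'}"}
    dataset = SQLTableDataset(table_name="table_a", credentials=credentials)
    data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
    dataset.save(data)
    assert data.equals(dataset.load())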

@@ -314,21 +314,31 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):

.. code-block:: pycon

+ >>> import sqlite3
+ >>>
>>> from kedro_datasets.pandas import SQLQueryDataset
>>> import pandas as pd
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>> sql = "SELECT * FROM table_a"
- >>> credentials = {"con": "postgresql://scott:tiger@localhost/test"}
- >>> data_set = SQLQueryDataset(sql=sql, credentials=credentials)
+ >>> credentials = {"con": f"sqlite:///{tmp_path / 'test.db'}"}
+ >>> dataset = SQLQueryDataset(sql=sql, credentials=credentials)
>>>
+ >>> con = sqlite3.connect(tmp_path / "test.db")
+ >>> cur = con.cursor()
+ >>> cur.execute("CREATE TABLE table_a(col1, col2, col3)")
+ <sqlite3.Cursor object at 0x...>
+ >>> cur.execute("INSERT INTO table_a VALUES (1, 4, 5), (2, 5, 6)")
+ <sqlite3.Cursor object at 0x...>
+ >>> con.commit()
+ >>> reloaded = dataset.load()
>>>
- >>> sql_data = data_set.load()
+ >>> assert data.equals(reloaded)
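
SQLQueryDataset is load-only (its signature above is AbstractDataset[None, pd.DataFrame]), so unlike SQLTableDataset it cannot save its own fixture data; that is why the rewritten example seeds table_a through raw sqlite3 calls before load(). The seeding could equally be sketched through pandas rather than the DB-API (an alternative, not what this PR does):

# Hypothetical alternative seeding step: write the DataFrame with pandas,
# then read it back through the dataset's query.
from sqlalchemy import create_engine

engine = create_engine(f"sqlite:///{tmp_path / 'test.db'}")
data.to_sql("table_a", con=engine, index=False)
assert data.equals(dataset.load())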

- Example of usage for mssql:
+ Example of usage for MSSQL:

.. code-block:: pycon


>>> credentials = {
... "server": "localhost",
... "port": "1433",
@@ -339,8 +349,8 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
>>> def _make_mssql_connection_str(
... server: str, port: str, database: str, user: str, password: str
... ) -> str:
- ... import pyodbc  # noqa
- ... from sqlalchemy.engine import URL  # noqa
+ ... import pyodbc
+ ... from sqlalchemy.engine import URL
... driver = pyodbc.drivers()[-1]
... connection_str = (
... f"DRIVER={driver};SERVER={server},{port};DATABASE={database};"
@@ -349,11 +359,11 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
... )
... return URL.create("mssql+pyodbc", query={"odbc_connect": connection_str})
...
- >>> connection_str = _make_mssql_connection_str(**credentials)
- >>> data_set = SQLQueryDataset(
+ >>> connection_str = _make_mssql_connection_str(**credentials)  # doctest: +SKIP
+ >>> dataset = SQLQueryDataset(  # doctest: +SKIP
... credentials={"con": connection_str}, sql="SELECT TOP 5 * FROM TestTable;"
... )
- >>> df = data_set.load()
+ >>> df = dataset.load()
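
The # doctest: +SKIP directives keep the connection-string construction and dataset instantiation from executing during the doctest run, since they would need a reachable SQL Server instance and an installed ODBC driver; defining the helper function itself is harmless, as it only imports pyodbc when called.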

In addition, here is an example of a catalog with dates parsing:
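
As a code-level companion to the catalog form, a sketch assuming load_args is forwarded to pandas.read_sql_query (parse_dates and index_col are pandas parameters, and the table and column names below are purely illustrative):

# Illustrative only: "orders" and "order_date" are made-up names.
dataset = SQLQueryDataset(
    sql="SELECT * FROM orders",
    credentials={"con": "sqlite:///orders.db"},
    load_args={"parse_dates": ["order_date"], "index_col": "order_date"},
)
df = dataset.load()  # order_date parsed to datetime and used as the index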
