diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 7311cf59..991a80c5 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -12,3 +12,14 @@ or from a conda environment :: conda install datajudge -c conda-forge + + + +Snowflake +^^^^^^^^^ + +If your backend is ``snowflake`` and you are querying large datasets, +you can additionally install ``pandas`` to make use of very fast query loading +(up to 50x speedup for large datasets). + Note: The ``pandas`` requirement is caused by a bug in snowflake-connector-python + and, hopefully, it will not be needed in the future. diff --git a/environment.yml b/environment.yml index 168d7c83..74816782 100644 --- a/environment.yml +++ b/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge - nodefaults dependencies: + - pandas - python>=3.8 - pytest - pytest-cov diff --git a/src/datajudge/db_access.py b/src/datajudge/db_access.py index a766f8f1..9346c3f8 100644 --- a/src/datajudge/db_access.py +++ b/src/datajudge/db_access.py @@ -3,6 +3,7 @@ import functools import json import operator +import warnings from abc import ABC, abstractmethod from collections import Counter from dataclasses import dataclass @@ -11,6 +12,17 @@ import sqlalchemy as sa from sqlalchemy.sql.expression import FromClause +from .utils import check_module_installed + +snowflake_available = check_module_installed("snowflake") +pandas_available = check_module_installed("pandas") + + +if snowflake_available and not pandas_available: + warnings.warn( + "For snowflake users: `pandas` is not installed, that means optimized data loading is not available." 
+ ) + def is_mssql(engine: sa.engine.Engine) -> bool: return engine.name == "mssql" @@ -648,7 +660,20 @@ def get_column( if not aggregate_operator: selection = sa.select([column]) - result = engine.connect().execute(selection).scalars().all() + + # snowflake-specific optimization iff pandas is installed + if is_snowflake(engine) and pandas_available: + snowflake_cursor = engine.connect().connection.cursor() + + # note: in addition to pyarrow, this currently requires pandas as well + pa_table = snowflake_cursor.execute(str(selection)).fetch_arrow_all() + if pa_table: # snowflake connector returns NoneType when the table is empty + result = pa_table.column(0).to_numpy() + else: + result = [] + + else: + result = engine.connect().execute(selection).scalars().all() else: selection = sa.select([aggregate_operator(column)]) diff --git a/src/datajudge/utils.py b/src/datajudge/utils.py new file mode 100644 index 00000000..b9da85c8 --- /dev/null +++ b/src/datajudge/utils.py @@ -0,0 +1,8 @@ +def check_module_installed(module_name: str) -> bool: + import importlib + + try: + importlib.import_module(module_name) + return True + except ModuleNotFoundError: + return False