From 7565ae9f8dfb632e1884a8db8b68e2c6446ef169 Mon Sep 17 00:00:00 2001 From: "ZhengYu, Xu" Date: Thu, 18 Apr 2024 12:00:12 +0800 Subject: [PATCH] refactor: Optimize the detection of whether a package is installed. Signed-off-by: ZhengYu, Xu --- connectorx-python/connectorx/__init__.py | 44 ++++++++---------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index 8415c5126..5caab03b3 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any - +import importlib from importlib.metadata import version +from typing import Any from .connectorx import ( read_sql as _read_sql, @@ -216,10 +216,7 @@ def read_sql( if return_type == "pandas": df = df.to_pandas(date_as_object=False, split_blocks=False) if return_type == "polars": - try: - import polars as pl - except ModuleNotFoundError: - raise ValueError("You need to install polars first") + pl = try_import_module("polars") try: # api change for polars >= 0.8.* @@ -255,10 +252,7 @@ def read_sql( conn, protocol = rewrite_conn(conn, protocol) if return_type in {"modin", "dask", "pandas"}: - try: - import pandas - except ModuleNotFoundError: - raise ValueError("You need to install pandas first") + try_import_module("pandas") result = _read_sql( conn, @@ -273,25 +267,14 @@ def read_sql( df.set_index(index_col, inplace=True) if return_type == "modin": - try: - import modin.pandas as mpd - except ModuleNotFoundError: - raise ValueError("You need to install modin first") - + mpd = try_import_module("modin.pandas") df = mpd.DataFrame(df) elif return_type == "dask": - try: - import dask.dataframe as dd - except ModuleNotFoundError: - raise ValueError("You need to install dask first") - + dd = try_import_module("dask.dataframe") df = dd.from_pandas(df, npartitions=1) elif return_type in {"arrow", "arrow2", "polars", "polars2"}: - try: - import pyarrow - except ModuleNotFoundError: - raise ValueError("You need to install pyarrow first") + try_import_module("pyarrow") result = _read_sql( conn, @@ -302,11 +285,7 @@ def read_sql( ) df = reconstruct_arrow(result) if return_type in {"polars", "polars2"}: - try: - import polars as pl - except ModuleNotFoundError: - raise ValueError("You need to install polars first") - + pl = try_import_module("polars") try: df = pl.DataFrame.from_arrow(df) except AttributeError: @@ -391,3 +370,10 @@ def remove_ending_semicolon(query: str) -> str: if query.endswith(';'): query = query[:-1] return query + + +def try_import_module(name: str): + try: + return importlib.import_module(name) + except ModuleNotFoundError: + raise ValueError(f"You need to install {name.split('.')[0]} first")