
Commit e6057ea
feat: add typing
Signed-off-by: Luka Peschke <[email protected]>
lukapeschke committed Dec 2, 2024
1 parent 07f76fe commit e6057ea
Showing 3 changed files with 24 additions and 21 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -163,15 +163,17 @@ ignore_missing_imports = true
files = [
"toucan_connectors/auth.py",
"toucan_connectors/awsathena/awsathena_connector.py",
"toucan_connectors/azure_mssql/azure_mssql_connector.py",
"toucan_connectors/common.py",
"toucan_connectors/google_big_query/google_big_query_connector.py",
"toucan_connectors/hubspot_private_app/hubspot_connector.py",
"toucan_connectors/mongo/mongo_connector.py",
"toucan_connectors/mysql/mysql_connector.py",
"toucan_connectors/peakina/peakina_connector.py",
"toucan_connectors/postgres/postgresql_connector.py",
"toucan_connectors/mysql/mysql_connector.py",
"toucan_connectors/redshift/redshift_database_connector.py",
"toucan_connectors/snowflake/snowflake_connector.py",
"toucan_connectors/snowflake_oauth2/snowflake_oauth2_connector.py",
"toucan_connectors/redshift/redshift_database_connector.py",
"toucan_connectors/toucan_connector.py",
]

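The hunk above only extends mypy's per-module allow-list in pyproject.toml; nothing in it changes runtime behaviour. As a rough illustration (the CLI arguments below are assumptions, only the file list itself comes from the diff), one of the newly listed modules can be checked programmatically through mypy's Python API:

# Sketch: run mypy on one of the listed connector modules via its Python API.
# mypy.api.run takes the same arguments as the command line and returns
# (stdout, stderr, exit_status).
from mypy import api

stdout, stderr, status = api.run(
    ["--config-file", "pyproject.toml", "toucan_connectors/azure_mssql/azure_mssql_connector.py"]
)
print(stdout or stderr)
print("mypy exit status:", status)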
8 changes: 4 additions & 4 deletions toucan_connectors/azure_mssql/azure_mssql_connector.py
@@ -11,7 +11,7 @@

from typing import TYPE_CHECKING, Annotated

-from pydantic import Field, StringConstraints
+from pydantic import Field, SecretStr, StringConstraints

from toucan_connectors.common import (
convert_jinja_params_to_sqlalchemy_named,
@@ -29,7 +29,7 @@
class AzureMSSQLDataSource(ToucanDataSource):
database: str = Field(..., description="The name of the database you want to query")
query: Annotated[str, StringConstraints(min_length=1)] = Field(
-..., description="You can write your SQL query here", widget="sql"
+..., description="You can write your SQL query here", json_schema_extra={"widget": "sql"}
)


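In pydantic v2, arbitrary keyword arguments to Field (such as widget="sql") are deprecated; extra JSON-schema metadata goes through json_schema_extra instead, which is what the change above does. A minimal sketch of the effect, on a standalone model rather than the real data source class:

from typing import Annotated
from pydantic import BaseModel, Field, StringConstraints

class ExampleDataSource(BaseModel):
    # json_schema_extra carries the custom "widget" hint into the generated schema
    query: Annotated[str, StringConstraints(min_length=1)] = Field(
        ..., description="You can write your SQL query here", json_schema_extra={"widget": "sql"}
    )

print(ExampleDataSource.model_json_schema()["properties"]["query"]["widget"])  # sql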
@@ -45,8 +45,8 @@ class AzureMSSQLConnector(ToucanConnector, data_source_model=AzureMSSQLDataSource
)

user: str = Field(..., description="Your login username")
-password: PlainJsonSecretStr = Field("", description="Your login password")
-connect_timeout: int = Field(
+password: PlainJsonSecretStr = Field(SecretStr(""), description="Your login password")
+connect_timeout: int | None = Field(
None,
title="Connection timeout",
description="You can set a connection timeout in seconds here, i.e. the maximum length of "
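The two fixes in this hunk are exactly what strict typing flags: a plain "" is not a valid default for a SecretStr-typed field, and a field defaulting to None must be annotated as optional. A minimal sketch, using plain SecretStr rather than the project's PlainJsonSecretStr wrapper:

from pydantic import BaseModel, Field, SecretStr

class ExampleConnector(BaseModel):
    # the default must itself be a SecretStr for the annotation to type-check
    password: SecretStr = Field(SecretStr(""), description="Your login password")
    # None is only an acceptable default if the annotation allows it
    connect_timeout: int | None = Field(None, title="Connection timeout")

conn = ExampleConnector()
print(conn.connect_timeout)                    # None
print(conn.password.get_secret_value() == "")  # True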
31 changes: 16 additions & 15 deletions toucan_connectors/common.py
@@ -63,12 +63,12 @@ def is_jinja_alone(s: str) -> bool:
return False


-def _has_parameters(query: dict | list[dict] | tuple | str) -> bool:
+def _has_parameters(query: str) -> bool:
t = Environment().parse(query) # noqa: S701
return bool(meta.find_undeclared_variables(t) or re.search(RE_PARAM, query))


-def _prepare_parameters(p: dict | list[dict] | tuple | str) -> dict | list[dict] | tuple | str:
+def _prepare_parameters(p: dict | list[dict] | tuple | str) -> dict | list[Any] | tuple | str:
if isinstance(p, str):
return repr(p)
elif isinstance(p, list):
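_has_parameters now only accepts a str, which matches how it inspects the query: it parses it as a Jinja template and looks for undeclared variables or printf-style placeholders. A self-contained sketch of that check (the RE_PARAM pattern below is a stand-in; the real constant is defined elsewhere in common.py):

import re
from jinja2 import Environment, meta

# Hypothetical stand-in for the module-level RE_PARAM constant,
# matching %(name)s-style placeholders.
RE_PARAM = re.compile(r"%\((\w*)\)s")

def has_parameters(query: str) -> bool:
    # A query "has parameters" if its Jinja template references undeclared
    # variables, or if it contains %(name)s placeholders.
    tree = Environment().parse(query)  # noqa: S701
    return bool(meta.find_undeclared_variables(tree) or re.search(RE_PARAM, query))

print(has_parameters("SELECT * FROM users WHERE id = {{ user_id }}"))  # True
print(has_parameters("SELECT * FROM users"))                           # False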
@@ -79,7 +79,7 @@ def _prepare_parameters(p: dict | list[dict] | tuple | str) -> dict | list[dict]
return p


-def _prepare_result(res: dict | list[dict] | tuple | str) -> dict | list[dict] | tuple | str:
+def _prepare_result(res: dict | list[dict] | tuple | str) -> dict | list[Any] | tuple | str:
if isinstance(res, str):
return ast.literal_eval(res)
elif isinstance(res, list):
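_prepare_parameters and _prepare_result are symmetric: strings are turned into their repr before template rendering and evaluated back afterwards, so quoting survives the round trip. An illustration of the scalar case only (the list and dict branches recurse element by element):

import ast

# Round trip assumed by the two helpers above.
value = "O'Reilly"
prepared = repr(value)               # safe to embed in a rendered template
restored = ast.literal_eval(prepared)
assert restored == value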
@@ -164,8 +164,8 @@ def _render_query(query: dict | list[dict] | tuple | str, parameters: dict | Non
clean_p = deepcopy(parameters)

if is_jinja_alone(query):
-clean_p = _prepare_parameters(clean_p)
-env = NativeEnvironment()
+clean_p = _prepare_parameters(clean_p) # type:ignore[assignment]
+env: Environment | NativeEnvironment = NativeEnvironment()
else:
env = Environment() # noqa: S701

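The explicit env: Environment | NativeEnvironment annotation is needed because either environment can be assigned depending on the query shape. The difference matters for rendering: a NativeEnvironment returns native Python objects, a plain Environment returns strings. A rough sketch (jinja_alone stands in for the is_jinja_alone check):

from jinja2 import Environment
from jinja2.nativetypes import NativeEnvironment

def render(template: str, params: dict, jinja_alone: bool) -> object:
    # The union annotation keeps mypy happy with either branch.
    env: Environment | NativeEnvironment = NativeEnvironment() if jinja_alone else Environment()  # noqa: S701
    return env.from_string(template).render(params)

print(render("{{ ids }}", {"ids": [1, 2, 3]}, jinja_alone=True))   # [1, 2, 3] (a real list)
print(render("{{ ids }}", {"ids": [1, 2, 3]}, jinja_alone=False))  # "[1, 2, 3]" (a string)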
@@ -243,7 +243,7 @@ def _flatten_dict(p, parent_key="")
# jq filtering


-def transform_with_jq(data: object, jq_filter: str) -> list:
+def transform_with_jq(data: Any, jq_filter: str) -> list:
import jq

data = jq.all(jq_filter, data)
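transform_with_jq delegates to jq.all, which always returns a list of results regardless of the filter, hence the -> list return type and the loosened data: Any parameter. For example:

import jq  # the `jq` PyPI package used by transform_with_jq

data = {"items": [{"name": "a", "value": 1}, {"name": "b", "value": 2}]}
print(jq.all(".items[] | .name", data))  # ['a', 'b']
print(jq.all(".items | length", data))   # [2]  -- still wrapped in a list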
@@ -344,7 +344,7 @@ def convert_to_qmark_paramstyle(query_string: str, params_values: dict) -> tuple
if isinstance(o, list):
# in the query string, replace the ? at index i by the number of item
# in the provided parameter of type list
-query_string = query_string.replace(extracted_params[i], f'({",".join(len(ordered_values[i])*["?"])})')
+query_string = query_string.replace(extracted_params[i], f'({",".join(len(o)*["?"])})')

flattened_values = []
for val in ordered_values:
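The fix replaces len(ordered_values[i]) with len(o), the loop variable that already holds that element, which reads better and lets the type checker narrow the list case. For reference, this is the expansion the loop performs for a list-valued parameter (a standalone sketch, simplified from the real convert_to_qmark_paramstyle):

import re

query = "SELECT name FROM students WHERE age IN %(allowed_ages)s"
params = {"allowed_ages": [16, 17, 18]}

extracted = re.findall(r"%\(\w+\)s", query)            # ['%(allowed_ages)s']
ordered_values = [params[m[2:-2]] for m in extracted]  # strip '%(' and ')s'

for i, o in enumerate(ordered_values):
    if isinstance(o, list):
        # replace the placeholder with one ? per list element
        query = query.replace(extracted[i], f'({",".join(len(o) * ["?"])})')

flattened = [v for o in ordered_values for v in (o if isinstance(o, list) else [o])]
print(query)      # SELECT name FROM students WHERE age IN (?,?,?)
print(flattened)  # [16, 17, 18]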
@@ -377,7 +377,7 @@ def convert_to_numeric_paramstyle(query_string: str, params_values: dict) -> tup
# query_string = "SELECT name FROM students WHERE age IN %(allowed_ages)"
# allowed_ages = [16, 17, 18]
# transformed query_string = "SELECT name FROM students WHERE age IN (:1,:2,:3)"
-list_size = len(ordered_values[i])
+list_size = len(o)
variable_list = f'({",".join([f":{variable_idx + n}" for n in range(list_size)])})'
query_string = query_string.replace(extracted_params[i], variable_list)
variable_idx += list_size
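Same idea for the numeric paramstyle, where each list element gets its own :n placeholder, matching the example in the comments above. A condensed sketch of that expansion:

query = "SELECT name FROM students WHERE age IN %(allowed_ages)s"
allowed_ages = [16, 17, 18]
variable_idx = 1  # numeric placeholders start at :1
variable_list = f'({",".join(f":{variable_idx + n}" for n in range(len(allowed_ages)))})'
print(query.replace("%(allowed_ages)s", variable_list))
# SELECT name FROM students WHERE age IN (:1,:2,:3)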
@@ -393,7 +393,8 @@ def convert_to_numeric_paramstyle(query_string: str, params_values: dict) -> tup
else:
flattened_values.append(val)

-return query_string, flattened_values
+# NOTE: we should probably return tuple(flattened_values) here but it could be breaking
+return query_string, flattened_values # type:ignore[return-value]


def convert_to_printf_templating_style(query_string: str) -> str:
@@ -455,8 +456,8 @@ def rename_duplicate_columns(df: "pd.DataFrame") -> None:

cols = pd.Series(df.columns)
for dup in df.columns[df.columns.duplicated(keep=False)]:
-cols[df.columns.get_loc(dup)] = [f"{dup}_{d_idx}" for d_idx in range(df.columns.get_loc(dup).sum())]
-df.columns = cols
+cols[df.columns.get_loc(dup)] = [f"{dup}_{d_idx}" for d_idx in range(df.columns.get_loc(dup).sum())] # type:ignore[union-attr]
+df.columns = cols # type:ignore[assignment]


def pandas_read_sql(
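The two type: ignore comments acknowledge that Index.get_loc can return an int, a slice, or a boolean mask, and that assigning a Series back to df.columns is not covered by the pandas stubs; the runtime behaviour is unchanged. What the function does, in a standalone sketch (assuming non-contiguous duplicates, where get_loc yields a boolean mask):

import pandas as pd

df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "a"])
cols = pd.Series(df.columns)
for dup in df.columns[df.columns.duplicated(keep=False)]:
    mask = df.columns.get_loc(dup)  # boolean mask because "a" is duplicated
    cols[mask] = [f"{dup}_{i}" for i in range(mask.sum())]
df.columns = cols
print(list(df.columns))  # ['a_0', 'b', 'a_1']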
@@ -490,10 +491,10 @@ def pandas_read_sql(
query = query.replace("%%", "%")
query = re.sub(r"%[^(%]", r"%\g<0>", query)
df = pd.read_sql(query, con=con, params=params, **kwargs)
-except pd.io.sql.DatabaseError as exc:
+except pd.errors.DatabaseError as exc:
if is_interpolating_table_name(query):
errmsg = f"Execution failed on sql '{query}': interpolating table name is forbidden"
-raise pd.io.sql.DatabaseError(errmsg) from exc
+raise pd.errors.DatabaseError(errmsg) from exc
else:
raise

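pd.errors.DatabaseError is the public location of the exception previously reached via pd.io.sql.DatabaseError; newer pandas stubs only know the former, so both the except clause and the re-raise move to it. For instance:

import sqlite3
import pandas as pd

# pandas wraps DBAPI failures in pd.errors.DatabaseError when reading
# through a raw connection.
con = sqlite3.connect(":memory:")
try:
    pd.read_sql("SELECT * FROM missing_table", con=con)
except pd.errors.DatabaseError as exc:
    print(f"query failed: {exc}")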
@@ -523,10 +524,10 @@ def pandas_read_sqlalchemy_query(
try:
conn = engine.connect()
df = pd.read_sql_query(sa_query, conn, params=params)
-except (pd.io.sql.DatabaseError, sa.exc.SQLAlchemyError) as exc:
+except (pd.errors.DatabaseError, sa.exc.SQLAlchemyError) as exc:
if is_interpolating_table_name(query):
errmsg = f"Execution failed on sql '{query}': interpolating table name is forbidden"
-raise pd.io.sql.DatabaseError(errmsg) from exc
+raise pd.errors.DatabaseError(errmsg) from exc
else:
raise

