
Commit

Merge remote-tracking branch 'origin/feature/DOP-16999' into feature/DOP-16999
maxim-lixakov committed Aug 20, 2024
2 parents 4dc8b56 + 81d86f6 commit 7d0f3a0
Showing 7 changed files with 51 additions and 104 deletions.
3 changes: 0 additions & 3 deletions docs/conf.py
@@ -120,9 +120,6 @@
{"rel": "icon", "href": "icon.svg", "type": "image/svg+xml"},
]

# TODO: remove after https://github.com/mgeier/sphinx-last-updated-by-git/pull/77
git_exclude_patterns = ["docs/_static/logo_wide.svg"]

# The master toctree document.
master_doc = "index"

4 changes: 2 additions & 2 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -90,7 +90,7 @@ def sql(

query = clear_statement(query)

log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
log_lines(log, query)

@@ -196,7 +196,7 @@ def get_df_schema(
columns: list[str] | None = None,
options: JDBCReadOptions | None = None,
) -> StructType:
log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source)

query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True)
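
For context (not part of the diff): a tiny runnable sketch of the log message these updated log.info calls produce. The class and dialect names below are illustrative values; the real code passes self.__class__.__name__ and self._get_spark_dialect_name() into the same format string.

import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger(__name__)

# Hypothetical values standing in for the connection class and detected dialect.
class_name = "Postgres"
dialect_name = "org.apache.spark.sql.jdbc.PostgresDialect"

log.info("|%s| Detected dialect: '%s'", class_name, dialect_name)
# prints: |Postgres| Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'
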
29 changes: 14 additions & 15 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -205,7 +205,7 @@ def fetch(

query = clear_statement(query)

log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, query)

@@ -278,7 +278,7 @@ def execute(

statement = clear_statement(statement)

log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Executing statement (on driver):", self.__class__.__name__)
log_lines(log, statement)

@@ -310,16 +310,6 @@ def execute(
log_lines(log, str(metrics))
return df

@property
def jdbc_dialect(self):
"""
Returns the JDBC dialect associated with the connection URL.
"""
jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url).toString()

return dialect.split("$")[0] if "$" in dialect else dialect

@validator("spark")
def _check_java_class_imported(cls, spark):
try:
@@ -429,6 +419,17 @@ def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions):
self._last_connection_and_options.data = (new_connection, options)
return new_connection

def _get_spark_dialect_name(self) -> str:
"""
Returns the name of the JDBC dialect associated with the connection URL.
"""
dialect = self._get_spark_dialect().toString()
return dialect.split("$")[0] if "$" in dialect else dialect

def _get_spark_dialect(self):
jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
return jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)

def _close_connections(self):
with suppress(Exception):
# connection maybe not opened yet
@@ -571,9 +572,7 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:

from pyspark.sql import DataFrame # noqa: WPS442

jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc # type: ignore
jdbc_dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)

jdbc_dialect = self._get_spark_dialect()
jdbc_utils_package = self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc # type: ignore
jdbc_utils = jdbc_utils_package.JdbcUtils

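
As a standalone illustration (not part of the diff) of what the new _get_spark_dialect()/_get_spark_dialect_name() helpers do: Spark keeps a registry of JDBC dialects that can be queried through the py4j JVM gateway, and the string form of the returned Scala object carries a '$' suffix that the helper strips off. A hypothetical free-standing version of the same lookup, assuming an active SparkSession with the relevant JDBC driver on the classpath:

from pyspark.sql import SparkSession


def spark_dialect_name(spark: SparkSession, jdbc_url: str) -> str:
    """Resolve the Spark JDBC dialect class name for a JDBC URL."""
    # Query Spark's JdbcDialects registry through the JVM gateway.
    jdbc_dialects = spark._jvm.org.apache.spark.sql.jdbc
    dialect = jdbc_dialects.JdbcDialects.get(jdbc_url).toString()
    # The raw string contains a '$' after the class name (Scala singleton
    # object), so keep only the part before the first '$'.
    return dialect.split("$")[0] if "$" in dialect else dialect


# e.g. spark_dialect_name(spark, "jdbc:postgresql://host:5432/db")
# -> 'org.apache.spark.sql.jdbc.PostgresDialect'
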
3 changes: 2 additions & 1 deletion requirements/docs.txt
@@ -9,7 +9,8 @@ sphinx<8
sphinx-copybutton
sphinx-design
sphinx-favicon
sphinx-last-updated-by-git
# https://github.com/mgeier/sphinx-last-updated-by-git/pull/77
sphinx-last-updated-by-git>=0.3.8
# TODO: uncomment after https://github.com/zqmillet/sphinx-plantuml/pull/4
# sphinx-plantuml
sphinx-tabs
@@ -62,7 +62,7 @@ def test_clickhouse_connection_check_extra_is_handled_by_driver(spark, processin


@pytest.mark.parametrize("suffix", ["", ";"])
def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):
def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
@@ -73,7 +73,11 @@ def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):
)

table = load_table_data.full_name
df = clickhouse.sql(f"SELECT * FROM {table}{suffix}")

with caplog.at_level(logging.INFO):
df = clickhouse.sql(f"SELECT * FROM {table}{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
table=load_table_data.table,
@@ -91,7 +95,7 @@ def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):


@pytest.mark.parametrize("suffix", ["", ";"])
def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix):
def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
@@ -103,7 +107,10 @@ def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix)

schema = load_table_data.schema
table = load_table_data.full_name
df = clickhouse.fetch(f"SELECT * FROM {table}{suffix}")

with caplog.at_level(logging.INFO):
df = clickhouse.fetch(f"SELECT * FROM {table}{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
@@ -192,7 +199,7 @@ def test_clickhouse_connection_execute_ddl(spark, processing, get_schema_table,

@pytest.mark.flaky
@pytest.mark.parametrize("suffix", ["", ";"])
def test_clickhouse_connection_execute_dml(request, spark, processing, load_table_data, suffix):
def test_clickhouse_connection_execute_dml(request, spark, processing, load_table_data, suffix, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
@@ -242,7 +249,9 @@ def table_finalizer():
updated_df = pandas.concat([updated_rows, unchanged_rows])
processing.assert_equal_df(df=df, other_frame=updated_df, order_by="id_int")

clickhouse.execute(f"UPDATE {temp_table} SET hwm_int = 1 WHERE id_int < 50{suffix}")
with caplog.at_level(logging.INFO):
clickhouse.execute(f"UPDATE {temp_table} SET hwm_int = 1 WHERE id_int < 50{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

clickhouse.execute(f"ALTER TABLE {temp_table} DELETE WHERE id_int < 70{suffix}")
df = clickhouse.fetch(f"SELECT * FROM {temp_table}{suffix}")
@@ -273,6 +282,7 @@ def test_clickhouse_connection_execute_function(
processing,
load_table_data,
suffix,
caplog,
):
clickhouse = Clickhouse(
host=processing.host,
@@ -321,38 +331,3 @@ def func_finalizer():
# wrong syntax
with pytest.raises(Exception):
clickhouse.execute(f"CREATE FUNCTION wrong_function AS (a, b) -> {suffix}")


def test_clickhouse_connection_no_jdbc_dialect(spark, processing, load_table_data, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
user=processing.user,
password=processing.password,
database=processing.database,
spark=spark,
)

table = load_table_data.full_name
clickhouse.get_df_schema(table)

with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

# clear the caplog buffer
caplog.clear()
clickhouse.sql("SELECT version()")
with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

# clear the caplog buffer
caplog.clear()
clickhouse.fetch("SELECT version()")
with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

# clear the caplog buffer
caplog.clear()
clickhouse.execute(f"TRUNCATE TABLE {table}")
with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
@@ -48,7 +48,7 @@ def test_postgres_connection_check_fail(spark):


@pytest.mark.parametrize("suffix", ["", ";"])
def test_postgres_connection_sql(spark, processing, load_table_data, suffix):
def test_postgres_connection_sql(spark, processing, load_table_data, suffix, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
@@ -60,7 +60,10 @@ def test_postgres_connection_sql(spark, processing, load_table_data, suffix):

table = load_table_data.full_name

df = postgres.sql(f"SELECT * FROM {table}{suffix}")
with caplog.at_level(logging.INFO):
df = postgres.sql(f"SELECT * FROM {table}{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
table=load_table_data.table,
@@ -79,7 +82,7 @@ def test_postgres_connection_sql(spark, processing, load_table_data, suffix):


@pytest.mark.parametrize("suffix", ["", ";"])
def test_postgres_connection_fetch(spark, processing, load_table_data, suffix):
def test_postgres_connection_fetch(spark, processing, load_table_data, suffix, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
@@ -91,7 +94,10 @@ def test_postgres_connection_fetch(spark, processing, load_table_data, suffix):

table = load_table_data.full_name

df = postgres.fetch(f"SELECT * FROM {table}{suffix}", Postgres.FetchOptions(fetchsize=2))
with caplog.at_level(logging.INFO):
df = postgres.fetch(f"SELECT * FROM {table}{suffix}", Postgres.FetchOptions(fetchsize=2))
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
table=load_table_data.table,
@@ -1023,7 +1029,7 @@ def test_postgres_connection_fetch_with_legacy_jdbc_options(spark, processing):
assert df is not None


def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing):
def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
@@ -1034,42 +1040,7 @@ def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing)
)

options = Postgres.JDBCOptions(query_timeout=30)
postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)


def test_postgres_connection_jdbc_dialect_usage(spark, processing, load_table_data, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
user=processing.user,
password=processing.password,
database=processing.database,
spark=spark,
)

table = load_table_data.full_name
postgres.get_df_schema(table)

with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

# clear the caplog buffer
caplog.clear()
postgres.sql("SELECT version()")
with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

caplog.clear()
postgres.fetch("SELECT version()")
with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

caplog.clear()
postgres.fetch("SELECT version()")
with caplog.at_level(logging.INFO):
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

caplog.clear()
postgres.execute(f"TRUNCATE TABLE {table}")
with caplog.at_level(logging.INFO):
postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
@@ -1,3 +1,4 @@
import logging
import re
import secrets
from datetime import date, datetime, timedelta
@@ -182,6 +183,7 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store(
hwm_column,
new_type,
step,
caplog,
):
postgres = Postgres(
host=processing.host,
@@ -200,7 +202,9 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store(

with IncrementalBatchStrategy(step=step) as batches:
for _ in batches:
reader.run()
with caplog.at_level(logging.INFO):
reader.run()
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

# change table schema
new_fields = {column_name: processing.get_column_type(column_name) for column_name in processing.column_names}
