
[DOP-16999] - Add jdbc_dialect logging #305

Merged: 3 commits, Aug 20, 2024
Changes from all commits
1 change: 1 addition & 0 deletions docs/changelog/next_release/305.feature.rst
@@ -0,0 +1 @@
Add log.info about JDBC dialect usage: ``Detected dialect: 'org.apache.spark.sql.jdbc.MySQLDialect'``
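As a usage sketch (not part of the diff): with INFO logging enabled, the new message appears next to the existing query logging. The connection parameters below are placeholders, the driver is assumed to be pulled in via onetl's Postgres.get_packages() helper, and the dialect class in the output depends on the JDBC URL.

import logging

from pyspark.sql import SparkSession

from onetl.connection import Postgres

logging.basicConfig(level=logging.INFO)

# Spark session with the Postgres JDBC driver on the classpath
spark = (
    SparkSession.builder
    .config("spark.jars.packages", ",".join(Postgres.get_packages()))
    .getOrCreate()
)

postgres = Postgres(
    host="localhost",  # placeholder connection details
    user="user",
    password="password",
    database="test",
    spark=spark,
)

df = postgres.sql("SELECT 1 AS id")
# Log output now includes, among the existing lines:
#   |Postgres| Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'
#   |Postgres| Executing SQL query (on executor):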
2 changes: 2 additions & 0 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -90,6 +90,7 @@ def sql(

query = clear_statement(query)

log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
log_lines(log, query)

@@ -195,6 +196,7 @@ def get_df_schema(
columns: list[str] | None = None,
options: JDBCReadOptions | None = None,
) -> StructType:
log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source)

query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True)
17 changes: 14 additions & 3 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -205,6 +205,7 @@ def fetch(

query = clear_statement(query)

log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, query)

@@ -277,6 +278,7 @@ def execute(

statement = clear_statement(statement)

log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
log.info("|%s| Executing statement (on driver):", self.__class__.__name__)
log_lines(log, statement)

@@ -417,6 +419,17 @@ def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions):
self._last_connection_and_options.data = (new_connection, options)
return new_connection

def _get_spark_dialect_name(self) -> str:
"""
Returns the name of the JDBC dialect associated with the connection URL.
"""
dialect = self._get_spark_dialect().toString()
return dialect.split("$")[0] if "$" in dialect else dialect

def _get_spark_dialect(self):
jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
return jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)

def _close_connections(self):
with suppress(Exception):
# connection maybe not opened yet
@@ -559,9 +572,7 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:

from pyspark.sql import DataFrame # noqa: WPS442

- jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc # type: ignore
- jdbc_dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)
-
+ jdbc_dialect = self._get_spark_dialect()
jdbc_utils_package = self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc # type: ignore
jdbc_utils = jdbc_utils_package.JdbcUtils

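For context on the _get_spark_dialect_name helper added above: Spark's built-in JDBC dialects are Scala singleton objects, so their JVM-side toString() typically renders as the class name followed by '$' and an identity hash, e.g. 'org.apache.spark.sql.jdbc.MySQLDialect$@6f2b958e'; the helper keeps only the part before the first '$'. A standalone sketch of that normalization (the sample strings are illustrative, not captured from a real session):

def normalize_dialect_name(raw: str) -> str:
    # Keep only the part before the first "$", mirroring _get_spark_dialect_name
    return raw.split("$")[0] if "$" in raw else raw

assert normalize_dialect_name("org.apache.spark.sql.jdbc.MySQLDialect$@6f2b958e") == "org.apache.spark.sql.jdbc.MySQLDialect"
assert normalize_dialect_name("org.apache.spark.sql.jdbc.NoopDialect$@1a2b3c4d") == "org.apache.spark.sql.jdbc.NoopDialect"
# A toString() without "$" passes through unchanged
assert normalize_dialect_name("org.apache.spark.sql.jdbc.PostgresDialect") == "org.apache.spark.sql.jdbc.PostgresDialect"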
@@ -62,7 +62,7 @@ def test_clickhouse_connection_check_extra_is_handled_by_driver(spark, processin


@pytest.mark.parametrize("suffix", ["", ";"])
def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):
def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
@@ -73,7 +73,11 @@ def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):
)

table = load_table_data.full_name
df = clickhouse.sql(f"SELECT * FROM {table}{suffix}")

with caplog.at_level(logging.INFO):
df = clickhouse.sql(f"SELECT * FROM {table}{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
table=load_table_data.table,
@@ -91,7 +95,7 @@ def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):


@pytest.mark.parametrize("suffix", ["", ";"])
def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix):
def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
@@ -103,7 +107,10 @@ def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix)

schema = load_table_data.schema
table = load_table_data.full_name
df = clickhouse.fetch(f"SELECT * FROM {table}{suffix}")

with caplog.at_level(logging.INFO):
df = clickhouse.fetch(f"SELECT * FROM {table}{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
@@ -192,7 +199,7 @@ def test_clickhouse_connection_execute_ddl(spark, processing, get_schema_table,

@pytest.mark.flaky
@pytest.mark.parametrize("suffix", ["", ";"])
def test_clickhouse_connection_execute_dml(request, spark, processing, load_table_data, suffix):
def test_clickhouse_connection_execute_dml(request, spark, processing, load_table_data, suffix, caplog):
clickhouse = Clickhouse(
host=processing.host,
port=processing.port,
@@ -242,7 +249,9 @@ def table_finalizer():
updated_df = pandas.concat([updated_rows, unchanged_rows])
processing.assert_equal_df(df=df, other_frame=updated_df, order_by="id_int")

clickhouse.execute(f"UPDATE {temp_table} SET hwm_int = 1 WHERE id_int < 50{suffix}")
with caplog.at_level(logging.INFO):
clickhouse.execute(f"UPDATE {temp_table} SET hwm_int = 1 WHERE id_int < 50{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

clickhouse.execute(f"ALTER TABLE {temp_table} DELETE WHERE id_int < 70{suffix}")
df = clickhouse.fetch(f"SELECT * FROM {temp_table}{suffix}")
@@ -273,6 +282,7 @@ def test_clickhouse_connection_execute_function(
processing,
load_table_data,
suffix,
caplog,
):
clickhouse = Clickhouse(
host=processing.host,
@@ -48,7 +48,7 @@ def test_postgres_connection_check_fail(spark):


@pytest.mark.parametrize("suffix", ["", ";"])
def test_postgres_connection_sql(spark, processing, load_table_data, suffix):
def test_postgres_connection_sql(spark, processing, load_table_data, suffix, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
@@ -60,7 +60,10 @@ def test_postgres_connection_sql(spark, processing, load_table_data, suffix):

table = load_table_data.full_name

df = postgres.sql(f"SELECT * FROM {table}{suffix}")
with caplog.at_level(logging.INFO):
df = postgres.sql(f"SELECT * FROM {table}{suffix}")
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
table=load_table_data.table,
@@ -79,7 +82,7 @@ def test_postgres_connection_sql(spark, processing, load_table_data, suffix):


@pytest.mark.parametrize("suffix", ["", ";"])
def test_postgres_connection_fetch(spark, processing, load_table_data, suffix):
def test_postgres_connection_fetch(spark, processing, load_table_data, suffix, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
@@ -91,7 +94,10 @@ def test_postgres_connection_fetch(spark, processing, load_table_data, suffix):

table = load_table_data.full_name

df = postgres.fetch(f"SELECT * FROM {table}{suffix}", Postgres.FetchOptions(fetchsize=2))
with caplog.at_level(logging.INFO):
df = postgres.fetch(f"SELECT * FROM {table}{suffix}", Postgres.FetchOptions(fetchsize=2))
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

table_df = processing.get_expected_dataframe(
schema=load_table_data.schema,
table=load_table_data.table,
@@ -1023,7 +1029,7 @@ def test_postgres_connection_fetch_with_legacy_jdbc_options(spark, processing):
assert df is not None


def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing):
def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing, caplog):
postgres = Postgres(
host=processing.host,
port=processing.port,
@@ -1034,4 +1040,7 @@ def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing)
)

options = Postgres.JDBCOptions(query_timeout=30)
postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)

with caplog.at_level(logging.INFO):
postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
@@ -1,3 +1,4 @@
import logging
import re
import secrets
from datetime import date, datetime, timedelta
@@ -182,6 +183,7 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store(
hwm_column,
new_type,
step,
caplog,
):
postgres = Postgres(
host=processing.host,
@@ -200,7 +202,9 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store(

with IncrementalBatchStrategy(step=step) as batches:
for _ in batches:
reader.run()
with caplog.at_level(logging.INFO):
reader.run()
assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

# change table schema
new_fields = {column_name: processing.get_column_type(column_name) for column_name in processing.column_names}