[DOP-16999] - Add jdbc_dialect logging #305

Merged
merged 3 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
1 change: 1 addition & 0 deletions docs/changelog/next_release/305.feature.rst
@@ -0,0 +1 @@
Add log.info about JDBC dialect usage: ``Detected dialect: 'org.apache.spark.sql.jdbc.MySQLDialect'``
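A minimal sketch of the resulting log output (not part of this diff), assuming a configured Postgres connection; host and credentials are placeholders, and the log formatting is illustrative:

from onetl.connection import Postgres

postgres = Postgres(
    host="example.org",
    user="user",
    password="***",
    database="db",
    spark=spark,  # an existing SparkSession
)
postgres.fetch("SELECT 1")
# INFO  |Postgres| Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'
# INFO  |Postgres| Executing SQL query (on driver):
# ...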
2 changes: 2 additions & 0 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -90,6 +90,7 @@ def sql(

        query = clear_statement(query)

        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
        log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
        log_lines(log, query)
@@ -195,6 +196,7 @@ def get_df_schema(
        columns: list[str] | None = None,
        options: JDBCReadOptions | None = None,
    ) -> StructType:
        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
        log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source)

        query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True)
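For context, get_df_schema above builds its query with limit=0, so the schema comes back without transferring any rows. A standalone sketch of the same idea (the SQL string and jdbc_url are assumptions; the real generated statement is dialect-specific):

# Illustrative only: fetch column names/types without reading any rows.
query = "SELECT id, name FROM schema.table WHERE 1 = 0"  # hypothetical generated SQL
df = (
    spark.read.format("jdbc")
    .option("url", jdbc_url)  # placeholder JDBC URL
    .option("query", query)
    .load()
)
print(df.schema)  # StructType with the table's columns; no data moved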
12 changes: 12 additions & 0 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -205,6 +205,7 @@ def fetch(

        query = clear_statement(query)

        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
        log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
        log_lines(log, query)

@@ -277,6 +278,7 @@ def execute(

        statement = clear_statement(statement)

        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
        log.info("|%s| Executing statement (on driver):", self.__class__.__name__)
        log_lines(log, statement)

@@ -308,6 +310,16 @@ def execute(
        log_lines(log, str(metrics))
        return df

    @property
    def jdbc_dialect(self) -> str:
        """
        Returns the JDBC dialect associated with the connection URL.
        """
        jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
        dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url).toString()

        # Scala singleton objects render as 'package.ClassName$...', so keep
        # only the part before the first '$' to get a clean class name.
        return dialect.split("$")[0] if "$" in dialect else dialect

@validator("spark")
def _check_java_class_imported(cls, spark):
try:
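The jdbc_dialect property above resolves the dialect through Spark's JVM gateway. A standalone sketch (not part of the PR) of what that call returns, assuming an active SparkSession; the exact toString() output varies by Spark version:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
jdbc = spark._jvm.org.apache.spark.sql.jdbc

raw = jdbc.JdbcDialects.get("jdbc:postgresql://localhost:5432/db").toString()
print(raw)                # e.g. 'org.apache.spark.sql.jdbc.PostgresDialect$@1a2b3c4d'
print(raw.split("$")[0])  # 'org.apache.spark.sql.jdbc.PostgresDialect'

# URLs with no registered dialect resolve to NoopDialect,
# which the ClickHouse test below relies on:
noop = jdbc.JdbcDialects.get("jdbc:clickhouse://localhost:8123/db").toString()
print(noop.split("$")[0])  # 'org.apache.spark.sql.jdbc.NoopDialect'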
@@ -321,3 +321,38 @@ def func_finalizer():
    # wrong syntax
    with pytest.raises(Exception):
        clickhouse.execute(f"CREATE FUNCTION wrong_function AS (a, b) -> {suffix}")


def test_clickhouse_connection_no_jdbc_dialect(spark, processing, load_table_data, caplog):
    clickhouse = Clickhouse(
        host=processing.host,
        port=processing.port,
        user=processing.user,
        password=processing.password,
        database=processing.database,
        spark=spark,
    )

    table = load_table_data.full_name
    clickhouse.get_df_schema(table)

    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

    # clear the caplog buffer
    caplog.clear()
    clickhouse.sql("SELECT version()")
    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

    # clear the caplog buffer
    caplog.clear()
    clickhouse.fetch("SELECT version()")
    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text

    # clear the caplog buffer
    caplog.clear()
    clickhouse.execute(f"TRUNCATE TABLE {table}")
    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
@@ -1035,3 +1035,41 @@ def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing)

    options = Postgres.JDBCOptions(query_timeout=30)
    postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)


def test_postgres_connection_jdbc_dialect_usage(spark, processing, load_table_data, caplog):
    postgres = Postgres(
        host=processing.host,
        port=processing.port,
        user=processing.user,
        password=processing.password,
        database=processing.database,
        spark=spark,
    )

    table = load_table_data.full_name
    postgres.get_df_schema(table)

    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

    # clear the caplog buffer
    caplog.clear()
    postgres.sql("SELECT version()")
    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

    caplog.clear()
    postgres.fetch("SELECT version()")
    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text

    caplog.clear()
    postgres.execute(f"TRUNCATE TABLE {table}")
    with caplog.at_level(logging.INFO):
        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text