diff --git a/docs/changelog/next_release/305.feature.rst b/docs/changelog/next_release/305.feature.rst
new file mode 100644
index 00000000..672a16d7
--- /dev/null
+++ b/docs/changelog/next_release/305.feature.rst
@@ -0,0 +1 @@
+Add ``log.info`` about JDBC dialect usage: ``Detected dialect: 'org.apache.spark.sql.jdbc.MySQLDialect'``
diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py
index 9d41298e..32c0b65c 100644
--- a/onetl/connection/db_connection/jdbc_connection/connection.py
+++ b/onetl/connection/db_connection/jdbc_connection/connection.py
@@ -90,6 +90,7 @@ def sql(
 
         query = clear_statement(query)
 
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
         log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
         log_lines(log, query)
 
@@ -195,6 +196,7 @@ def get_df_schema(
         columns: list[str] | None = None,
         options: JDBCReadOptions | None = None,
     ) -> StructType:
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
         log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source)
 
         query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True)
diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py
index 8ec77d13..5d309e58 100644
--- a/onetl/connection/db_connection/jdbc_mixin/connection.py
+++ b/onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -205,6 +205,7 @@ def fetch(
 
         query = clear_statement(query)
 
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
         log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
         log_lines(log, query)
 
@@ -277,6 +278,7 @@ def execute(
 
         statement = clear_statement(statement)
 
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
         log.info("|%s| Executing statement (on driver):", self.__class__.__name__)
         log_lines(log, statement)
 
@@ -308,6 +310,21 @@ def execute(
         log_lines(log, str(metrics))
         return df
 
+    @property
+    def jdbc_dialect(self) -> str:
+        """
+        Returns the JDBC dialect associated with the connection URL.
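+
+        Example (illustrative only; assumes a configured ``Postgres`` connection)::
+
+            postgres.jdbc_dialect
+            # 'org.apache.spark.sql.jdbc.PostgresDialect'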
+ """ + jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc + dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url).toString() + + return dialect.split("$")[0] if "$" in dialect else dialect + @validator("spark") def _check_java_class_imported(cls, spark): try: diff --git a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py index 78656d83..1bd2ebf3 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py @@ -321,3 +321,38 @@ def func_finalizer(): # wrong syntax with pytest.raises(Exception): clickhouse.execute(f"CREATE FUNCTION wrong_function AS (a, b) -> {suffix}") + + +def test_clickhouse_connection_no_jdbc_dialect(spark, processing, load_table_data, caplog): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + table = load_table_data.full_name + clickhouse.get_df_schema(table) + + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text + + # clear the caplog buffer + caplog.clear() + clickhouse.sql("SELECT version()") + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text + + # clear the caplog buffer + caplog.clear() + clickhouse.fetch("SELECT version()") + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text + + # clear the caplog buffer + caplog.clear() + clickhouse.execute(f"TRUNCATE TABLE {table}") + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text diff --git a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py index 6cea95cc..0b05e7f9 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py @@ -1035,3 +1035,41 @@ def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing) options = Postgres.JDBCOptions(query_timeout=30) postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options) + + +def test_postgres_connection_jdbc_dialect_usage(spark, processing, load_table_data, caplog): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + table = load_table_data.full_name + postgres.get_df_schema(table) + + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text + + # clear the caplog buffer + caplog.clear() + postgres.sql("SELECT version()") + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text + + caplog.clear() + postgres.fetch("SELECT version()") + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text + + caplog.clear() + postgres.fetch("SELECT version()") + with caplog.at_level(logging.INFO): + assert 
"Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text + + caplog.clear() + postgres.execute(f"TRUNCATE TABLE {table}") + with caplog.at_level(logging.INFO): + assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text