diff --git a/docs/conf.py b/docs/conf.py
index f781dddd4..867d4daff 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -120,9 +120,6 @@
     {"rel": "icon", "href": "icon.svg", "type": "image/svg+xml"},
 ]
 
-# TODO: remove after https://github.com/mgeier/sphinx-last-updated-by-git/pull/77
-git_exclude_patterns = ["docs/_static/logo_wide.svg"]
-
 # The master toctree document.
 master_doc = "index"
 
diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py
index 32c0b65c6..0f3ac024e 100644
--- a/onetl/connection/db_connection/jdbc_connection/connection.py
+++ b/onetl/connection/db_connection/jdbc_connection/connection.py
@@ -90,7 +90,7 @@ def sql(
 
         query = clear_statement(query)
 
-        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
         log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__)
         log_lines(log, query)
 
@@ -196,7 +196,7 @@ def get_df_schema(
         columns: list[str] | None = None,
         options: JDBCReadOptions | None = None,
     ) -> StructType:
-        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
         log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source)
 
         query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True)
diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py
index 5d309e584..2f25b5a9f 100644
--- a/onetl/connection/db_connection/jdbc_mixin/connection.py
+++ b/onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -205,7 +205,7 @@ def fetch(
 
         query = clear_statement(query)
 
-        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
         log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
         log_lines(log, query)
 
@@ -278,7 +278,7 @@ def execute(
 
         statement = clear_statement(statement)
 
-        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self.jdbc_dialect)
+        log.info("|%s| Detected dialect: '%s'", self.__class__.__name__, self._get_spark_dialect_name())
         log.info("|%s| Executing statement (on driver):", self.__class__.__name__)
         log_lines(log, statement)
 
@@ -310,16 +310,6 @@ def execute(
             log_lines(log, str(metrics))
         return df
 
-    @property
-    def jdbc_dialect(self):
-        """
-        Returns the JDBC dialect associated with the connection URL.
-        """
-        jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
-        dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url).toString()
-
-        return dialect.split("$")[0] if "$" in dialect else dialect
-
     @validator("spark")
     def _check_java_class_imported(cls, spark):
         try:
@@ -429,6 +419,17 @@ def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions):
         self._last_connection_and_options.data = (new_connection, options)
         return new_connection
 
+    def _get_spark_dialect_name(self) -> str:
+        """
+        Returns the name of the JDBC dialect associated with the connection URL.
+        """
+        dialect = self._get_spark_dialect().toString()
+        return dialect.split("$")[0] if "$" in dialect else dialect
+
+    def _get_spark_dialect(self):
+        jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
+        return jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)
+
     def _close_connections(self):
         with suppress(Exception):
             # connection maybe not opened yet
@@ -571,9 +572,7 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:
 
         from pyspark.sql import DataFrame  # noqa: WPS442
 
-        jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc  # type: ignore
-        jdbc_dialect = jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)
-
+        jdbc_dialect = self._get_spark_dialect()
         jdbc_utils_package = self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc  # type: ignore
         jdbc_utils = jdbc_utils_package.JdbcUtils
 
diff --git a/requirements/docs.txt b/requirements/docs.txt
index be2cd1275..877683501 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -9,7 +9,8 @@ sphinx<8
 sphinx-copybutton
 sphinx-design
 sphinx-favicon
-sphinx-last-updated-by-git
+# https://github.com/mgeier/sphinx-last-updated-by-git/pull/77
+sphinx-last-updated-by-git>=0.3.8
 # TODO: uncomment after https://github.com/zqmillet/sphinx-plantuml/pull/4
 # sphinx-plantuml
 sphinx-tabs
diff --git a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py
index 1bd2ebf32..aa9205b83 100644
--- a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py
+++ b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py
@@ -62,7 +62,7 @@ def test_clickhouse_connection_check_extra_is_handled_by_driver(spark, processin
 
 
 @pytest.mark.parametrize("suffix", ["", ";"])
-def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):
+def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix, caplog):
     clickhouse = Clickhouse(
         host=processing.host,
         port=processing.port,
@@ -73,7 +73,11 @@
     )
 
     table = load_table_data.full_name
-    df = clickhouse.sql(f"SELECT * FROM {table}{suffix}")
+
+    with caplog.at_level(logging.INFO):
+        df = clickhouse.sql(f"SELECT * FROM {table}{suffix}")
+        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
+
     table_df = processing.get_expected_dataframe(
         schema=load_table_data.schema,
         table=load_table_data.table,
@@ -91,7 +95,7 @@
 
 
 @pytest.mark.parametrize("suffix", ["", ";"])
-def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix):
+def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix, caplog):
     clickhouse = Clickhouse(
         host=processing.host,
         port=processing.port,
@@ -103,7 +107,10 @@ def test_clickhouse_connection_fetch(spark, processing, load_table_data, suffix)
 
     schema = load_table_data.schema
     table = load_table_data.full_name
-    df = clickhouse.fetch(f"SELECT * FROM {table}{suffix}")
+
+    with caplog.at_level(logging.INFO):
+        df = clickhouse.fetch(f"SELECT * FROM {table}{suffix}")
+        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
 
     table_df = processing.get_expected_dataframe(
         schema=load_table_data.schema,
@@ -192,7 +199,7 @@ def test_clickhouse_connection_execute_ddl(spark, processing, get_schema_table,
 
 @pytest.mark.flaky
 @pytest.mark.parametrize("suffix", ["", ";"])
-def test_clickhouse_connection_execute_dml(request, spark, processing, load_table_data, suffix):
+def test_clickhouse_connection_execute_dml(request, spark, processing, load_table_data, suffix, caplog):
     clickhouse = Clickhouse(
         host=processing.host,
         port=processing.port,
@@ -242,7 +249,9 @@ def table_finalizer():
     updated_df = pandas.concat([updated_rows, unchanged_rows])
     processing.assert_equal_df(df=df, other_frame=updated_df, order_by="id_int")
 
-    clickhouse.execute(f"UPDATE {temp_table} SET hwm_int = 1 WHERE id_int < 50{suffix}")
+    with caplog.at_level(logging.INFO):
+        clickhouse.execute(f"UPDATE {temp_table} SET hwm_int = 1 WHERE id_int < 50{suffix}")
+        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
     clickhouse.execute(f"ALTER TABLE {temp_table} DELETE WHERE id_int < 70{suffix}")
 
     df = clickhouse.fetch(f"SELECT * FROM {temp_table}{suffix}")
@@ -273,6 +282,7 @@ def test_clickhouse_connection_execute_function(
     processing,
     load_table_data,
     suffix,
+    caplog,
 ):
     clickhouse = Clickhouse(
         host=processing.host,
@@ -321,38 +331,3 @@ def func_finalizer():
     # wrong syntax
     with pytest.raises(Exception):
         clickhouse.execute(f"CREATE FUNCTION wrong_function AS (a, b) -> {suffix}")
-
-
-def test_clickhouse_connection_no_jdbc_dialect(spark, processing, load_table_data, caplog):
-    clickhouse = Clickhouse(
-        host=processing.host,
-        port=processing.port,
-        user=processing.user,
-        password=processing.password,
-        database=processing.database,
-        spark=spark,
-    )
-
-    table = load_table_data.full_name
-    clickhouse.get_df_schema(table)
-
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
-
-    # clear the caplog buffer
-    caplog.clear()
-    clickhouse.sql("SELECT version()")
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
-
-    # clear the caplog buffer
-    caplog.clear()
-    clickhouse.fetch("SELECT version()")
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
-
-    # clear the caplog buffer
-    caplog.clear()
-    clickhouse.execute(f"TRUNCATE TABLE {table}")
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.NoopDialect'" in caplog.text
diff --git a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py
index 0b05e7f9e..ead0275e2 100644
--- a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py
+++ b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py
@@ -48,7 +48,7 @@ def test_postgres_connection_check_fail(spark):
 
 
 @pytest.mark.parametrize("suffix", ["", ";"])
-def test_postgres_connection_sql(spark, processing, load_table_data, suffix):
+def test_postgres_connection_sql(spark, processing, load_table_data, suffix, caplog):
     postgres = Postgres(
         host=processing.host,
         port=processing.port,
@@ -60,7 +60,10 @@ def test_postgres_connection_sql(spark, processing, load_table_data, suffix):
 
     table = load_table_data.full_name
 
-    df = postgres.sql(f"SELECT * FROM {table}{suffix}")
+    with caplog.at_level(logging.INFO):
+        df = postgres.sql(f"SELECT * FROM {table}{suffix}")
+        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
+
     table_df = processing.get_expected_dataframe(
         schema=load_table_data.schema,
         table=load_table_data.table,
@@ -79,7 +82,7 @@
 
 
 @pytest.mark.parametrize("suffix", ["", ";"])
-def test_postgres_connection_fetch(spark, processing, load_table_data, suffix):
+def test_postgres_connection_fetch(spark, processing, load_table_data, suffix, caplog):
     postgres = Postgres(
         host=processing.host,
         port=processing.port,
@@ -91,7 +94,10 @@ def test_postgres_connection_fetch(spark, processing, load_table_data, suffix):
 
     table = load_table_data.full_name
 
-    df = postgres.fetch(f"SELECT * FROM {table}{suffix}", Postgres.FetchOptions(fetchsize=2))
+    with caplog.at_level(logging.INFO):
+        df = postgres.fetch(f"SELECT * FROM {table}{suffix}", Postgres.FetchOptions(fetchsize=2))
+        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
+
     table_df = processing.get_expected_dataframe(
         schema=load_table_data.schema,
         table=load_table_data.table,
@@ -1023,7 +1029,7 @@ def test_postgres_connection_fetch_with_legacy_jdbc_options(spark, processing):
     assert df is not None
 
 
-def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing):
+def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing, caplog):
     postgres = Postgres(
         host=processing.host,
         port=processing.port,
@@ -1034,42 +1040,7 @@ def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing)
     )
 
     options = Postgres.JDBCOptions(query_timeout=30)
 
-    postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)
-
-
-def test_postgres_connection_jdbc_dialect_usage(spark, processing, load_table_data, caplog):
-    postgres = Postgres(
-        host=processing.host,
-        port=processing.port,
-        user=processing.user,
-        password=processing.password,
-        database=processing.database,
-        spark=spark,
-    )
-
-    table = load_table_data.full_name
-    postgres.get_df_schema(table)
-
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
-
-    # clear the caplog buffer
-    caplog.clear()
-    postgres.sql("SELECT version()")
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
-
-    caplog.clear()
-    postgres.fetch("SELECT version()")
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
-
-    caplog.clear()
-    postgres.fetch("SELECT version()")
-    with caplog.at_level(logging.INFO):
-        assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
-    caplog.clear()
-    postgres.execute(f"TRUNCATE TABLE {table}")
     with caplog.at_level(logging.INFO):
+        postgres.execute("DROP TABLE IF EXISTS temp_table;", options=options)
         assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
diff --git a/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py b/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py
index 66c7ad31a..e72b91e83 100644
--- a/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py
+++ b/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py
@@ -1,3 +1,4 @@
+import logging
 import re
 import secrets
 from datetime import date, datetime, timedelta
@@ -182,6 +183,7 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store(
     hwm_column,
     new_type,
     step,
+    caplog,
 ):
     postgres = Postgres(
         host=processing.host,
@@ -200,7 +202,9 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store(
 
     with IncrementalBatchStrategy(step=step) as batches:
         for _ in batches:
-            reader.run()
+            with caplog.at_level(logging.INFO):
+                reader.run()
+                assert "Detected dialect: 'org.apache.spark.sql.jdbc.PostgresDialect'" in caplog.text
 
     # change table schema
     new_fields = {column_name: processing.get_column_type(column_name) for column_name in processing.column_names}
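
Note: the new `_get_spark_dialect_name()` helper derives a readable dialect name from the JVM object returned by `JdbcDialects.get(self.jdbc_url)`. A minimal, Spark-free sketch of just that name-normalization step follows; the standalone `dialect_name` function below is hypothetical and only illustrates the "$"-stripping that produces the names asserted in the updated tests.

    # Hypothetical stand-alone illustration (not part of the patch).
    # _get_spark_dialect_name() calls JdbcDialects.get(self.jdbc_url).toString() via Py4J;
    # Scala dialect objects may stringify with a trailing "$..." suffix, so only the part
    # before the first "$" is kept for logging.
    def dialect_name(jvm_to_string: str) -> str:
        return jvm_to_string.split("$")[0] if "$" in jvm_to_string else jvm_to_string

    assert dialect_name("org.apache.spark.sql.jdbc.PostgresDialect$") == "org.apache.spark.sql.jdbc.PostgresDialect"
    assert dialect_name("org.apache.spark.sql.jdbc.NoopDialect") == "org.apache.spark.sql.jdbc.NoopDialect"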