From 6c4e8da7d9571fe3e78a17aa70194ae60d2d6a6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D1=80=D1=82=D1=8B=D0=BD=D0=BE=D0=B2=20=D0=9C?= =?UTF-8?q?=D0=B0=D0=BA=D1=81=D0=B8=D0=BC=20=D0=A1=D0=B5=D1=80=D0=B3=D0=B5?= =?UTF-8?q?=D0=B5=D0=B2=D0=B8=D1=87?= Date: Tue, 27 Aug 2024 12:40:00 +0000 Subject: [PATCH] [DOP-19024] Fix passing custom JDBC options to Greenplum.extra --- docs/changelog/next_release/308.bugfix.rst | 1 + .../db_connection/greenplum/connection.py | 42 ++++++++++++------- .../test_greenplum_unit.py | 28 +++++++++++++ 3 files changed, 57 insertions(+), 14 deletions(-) create mode 100644 docs/changelog/next_release/308.bugfix.rst diff --git a/docs/changelog/next_release/308.bugfix.rst b/docs/changelog/next_release/308.bugfix.rst new file mode 100644 index 00000000..3ffcdcc5 --- /dev/null +++ b/docs/changelog/next_release/308.bugfix.rst @@ -0,0 +1 @@ +Fix passing ``Greenplum(extra={"options": ...)`` during read/write operations. diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index 0f40436f..cc3191af 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -7,6 +7,7 @@ import textwrap import warnings from typing import TYPE_CHECKING, Any, ClassVar +from urllib.parse import quote, urlencode, urlparse, urlunparse from etl_entities.instance import Host @@ -274,17 +275,20 @@ def __str__(self): def jdbc_url(self) -> str: return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}" + @property + def jdbc_custom_params(self) -> dict: + result = { + key: value + for key, value in self.extra.dict(by_alias=True).items() + if not (key.startswith("server.") or key.startswith("pool.")) + } + result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName) + return result + @property def jdbc_params(self) -> dict: result = super().jdbc_params - result.update( - { - key: value - for key, value in self.extra.dict(by_alias=True).items() - if not (key.startswith("server.") or key.startswith("pool.")) - }, - ) - result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName) + result.update(self.jdbc_custom_params) return result @slot @@ -305,7 +309,7 @@ def read_source_as_df( fake_query_for_log = self.dialect.get_sql_query(table=source, columns=columns, where=where, limit=limit) log_lines(log, fake_query_for_log) - df = self.spark.read.format("greenplum").options(**self._connector_params(source), **read_options).load() + df = self.spark.read.format("greenplum").options(**self._get_connector_params(source), **read_options).load() self._check_expected_jobs_number(df, action="read") if where: @@ -340,7 +344,7 @@ def write_df_to_target( else write_options.if_exists.value ) df.write.format("greenplum").options( - **self._connector_params(target), + **self._get_connector_params(target), **options_dict, ).mode(mode).save() @@ -425,21 +429,31 @@ def _check_java_class_imported(cls, spark): raise ValueError(msg) from e return spark - def _connector_params( + def _get_connector_params( self, table: str, ) -> dict: schema, table_name = table.split(".") # noqa: WPS414 extra = self.extra.dict(by_alias=True, exclude_none=True) - extra = {key: value for key, value in extra.items() if key.startswith("server.") or key.startswith("pool.")} + greenplum_connector_options = { + key: value for key, value in extra.items() if key.startswith("server.") or key.startswith("pool.") + } + + # Greenplum connector requires all JDBC params to be passed via JDBC URL: + # https://docs.vmware.com/en/VMware-Greenplum-Connector-for-Apache-Spark/2.3/greenplum-connector-spark/using_the_connector.html#specifying-session-parameters + parsed_jdbc_url = urlparse(self.jdbc_url) + sorted_jdbc_params = [(k, v) for k, v in sorted(self.jdbc_custom_params.items(), key=lambda x: x[0].lower())] + jdbc_url_query = urlencode(sorted_jdbc_params, quote_via=quote) + jdbc_url = urlunparse(parsed_jdbc_url._replace(query=jdbc_url_query)) + return { "driver": self.DRIVER, - "url": self.jdbc_url, + "url": jdbc_url, "user": self.user, "password": self.password.get_secret_value(), "dbschema": schema, "dbtable": table_name, - **extra, + **greenplum_connector_options, } def _options_to_connection_properties(self, options: JDBCFetchOptions | JDBCExecuteOptions): diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py index 47821642..b6ea9544 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py @@ -128,6 +128,14 @@ def test_greenplum(spark_mock): "ApplicationName": "abc", "tcpKeepAlive": "true", } + assert conn._get_connector_params("some.table") == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5432/database?ApplicationName=abc&tcpKeepAlive=true", + "dbschema": "some", + "dbtable": "table", + } assert "passwd" not in repr(conn) @@ -154,6 +162,14 @@ def test_greenplum_with_port(spark_mock): "ApplicationName": "abc", "tcpKeepAlive": "true", } + assert conn._get_connector_params("some.table") == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5000/database?ApplicationName=abc&tcpKeepAlive=true", + "dbschema": "some", + "dbtable": "table", + } assert conn.instance_url == "greenplum://some_host:5000/database" assert str(conn) == "Greenplum[some_host:5000/database]" @@ -174,6 +190,7 @@ def test_greenplum_with_extra(spark_mock): "autosave": "always", "tcpKeepAlive": "false", "ApplicationName": "override", + "options": "-c search_path=public", "server.port": 8000, "pool.maxSize": 40, }, @@ -191,6 +208,17 @@ def test_greenplum_with_extra(spark_mock): "ApplicationName": "override", "tcpKeepAlive": "false", "autosave": "always", + "options": "-c search_path=public", + } + assert conn._get_connector_params("some.table") == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5432/database?ApplicationName=override&autosave=always&options=-c%20search_path%3Dpublic&tcpKeepAlive=false", + "dbschema": "some", + "dbtable": "table", + "pool.maxSize": 40, + "server.port": 8000, }