Skip to content

Commit

Permalink
[DOP-19024] Fix passing custom JDBC options to Greenplum.extra
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Aug 27, 2024
1 parent c6b09be commit 6c4e8da
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/308.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix passing ``Greenplum(extra={"options": ...})`` during read/write operations.
42 changes: 28 additions & 14 deletions onetl/connection/db_connection/greenplum/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, ClassVar
from urllib.parse import quote, urlencode, urlparse, urlunparse

from etl_entities.instance import Host

Expand Down Expand Up @@ -274,17 +275,20 @@ def __str__(self):
def jdbc_url(self) -> str:
return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"

@property
def jdbc_custom_params(self) -> dict:
    """Return the JDBC parameters taken from ``extra``.

    Every ``extra`` item is kept except those whose key starts with
    ``server.`` or ``pool.``. ``ApplicationName`` falls back to the Spark
    application name when ``extra`` does not provide it.
    """
    result = {
        key: value
        for key, value in self.extra.dict(by_alias=True).items()
        # ``startswith`` accepts a tuple, so one call covers both prefixes
        if not key.startswith(("server.", "pool."))
    }
    # keep a user-supplied ApplicationName, otherwise use the Spark app name
    result.setdefault("ApplicationName", self.spark.sparkContext.appName)
    return result

@property
def jdbc_params(self) -> dict:
    """Return the parent class JDBC params overlaid with ``jdbc_custom_params``."""
    result = super().jdbc_params
    # ``jdbc_custom_params`` already filters ``extra`` and fills in
    # ``ApplicationName``, so that logic is not repeated here.
    result.update(self.jdbc_custom_params)
    return result

@slot
Expand All @@ -305,7 +309,7 @@ def read_source_as_df(
fake_query_for_log = self.dialect.get_sql_query(table=source, columns=columns, where=where, limit=limit)
log_lines(log, fake_query_for_log)

df = self.spark.read.format("greenplum").options(**self._connector_params(source), **read_options).load()
df = self.spark.read.format("greenplum").options(**self._get_connector_params(source), **read_options).load()
self._check_expected_jobs_number(df, action="read")

if where:
Expand Down Expand Up @@ -340,7 +344,7 @@ def write_df_to_target(
else write_options.if_exists.value
)
df.write.format("greenplum").options(
**self._connector_params(target),
**self._get_connector_params(target),
**options_dict,
).mode(mode).save()

Expand Down Expand Up @@ -425,21 +429,31 @@ def _check_java_class_imported(cls, spark):
raise ValueError(msg) from e
return spark

def _connector_params(
def _get_connector_params(
self,
table: str,
) -> dict:
schema, table_name = table.split(".") # noqa: WPS414
extra = self.extra.dict(by_alias=True, exclude_none=True)
extra = {key: value for key, value in extra.items() if key.startswith("server.") or key.startswith("pool.")}
greenplum_connector_options = {
key: value for key, value in extra.items() if key.startswith("server.") or key.startswith("pool.")
}

# Greenplum connector requires all JDBC params to be passed via JDBC URL:
# https://docs.vmware.com/en/VMware-Greenplum-Connector-for-Apache-Spark/2.3/greenplum-connector-spark/using_the_connector.html#specifying-session-parameters
parsed_jdbc_url = urlparse(self.jdbc_url)
sorted_jdbc_params = [(k, v) for k, v in sorted(self.jdbc_custom_params.items(), key=lambda x: x[0].lower())]
jdbc_url_query = urlencode(sorted_jdbc_params, quote_via=quote)
jdbc_url = urlunparse(parsed_jdbc_url._replace(query=jdbc_url_query))

return {
"driver": self.DRIVER,
"url": self.jdbc_url,
"url": jdbc_url,
"user": self.user,
"password": self.password.get_secret_value(),
"dbschema": schema,
"dbtable": table_name,
**extra,
**greenplum_connector_options,
}

def _options_to_connection_properties(self, options: JDBCFetchOptions | JDBCExecuteOptions):
Expand Down
28 changes: 28 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ def test_greenplum(spark_mock):
"ApplicationName": "abc",
"tcpKeepAlive": "true",
}
assert conn._get_connector_params("some.table") == {
"user": "user",
"password": "passwd",
"driver": "org.postgresql.Driver",
"url": "jdbc:postgresql://some_host:5432/database?ApplicationName=abc&tcpKeepAlive=true",
"dbschema": "some",
"dbtable": "table",
}

assert "passwd" not in repr(conn)

Expand All @@ -154,6 +162,14 @@ def test_greenplum_with_port(spark_mock):
"ApplicationName": "abc",
"tcpKeepAlive": "true",
}
assert conn._get_connector_params("some.table") == {
"user": "user",
"password": "passwd",
"driver": "org.postgresql.Driver",
"url": "jdbc:postgresql://some_host:5000/database?ApplicationName=abc&tcpKeepAlive=true",
"dbschema": "some",
"dbtable": "table",
}

assert conn.instance_url == "greenplum://some_host:5000/database"
assert str(conn) == "Greenplum[some_host:5000/database]"
Expand All @@ -174,6 +190,7 @@ def test_greenplum_with_extra(spark_mock):
"autosave": "always",
"tcpKeepAlive": "false",
"ApplicationName": "override",
"options": "-c search_path=public",
"server.port": 8000,
"pool.maxSize": 40,
},
Expand All @@ -191,6 +208,17 @@ def test_greenplum_with_extra(spark_mock):
"ApplicationName": "override",
"tcpKeepAlive": "false",
"autosave": "always",
"options": "-c search_path=public",
}
assert conn._get_connector_params("some.table") == {
"user": "user",
"password": "passwd",
"driver": "org.postgresql.Driver",
"url": "jdbc:postgresql://some_host:5432/database?ApplicationName=override&autosave=always&options=-c%20search_path%3Dpublic&tcpKeepAlive=false",
"dbschema": "some",
"dbtable": "table",
"pool.maxSize": 40,
"server.port": 8000,
}


Expand Down

0 comments on commit 6c4e8da

Please sign in to comment.