From d91bd17fc919536820131181ce959110920da063 Mon Sep 17 00:00:00 2001 From: Cor Date: Mon, 24 Jul 2023 20:07:58 +0200 Subject: [PATCH] Add server side parameters to session connection method (#823) * Pass existing server_side_parameters to session connection wrapper and use to configure SparkSession. * Incorporating feedback. Moved server side parameters to Connection and pass to cursor from there. * Add changie * Add type hint * Write out loop * Add type hint * Remove server_side_parameters from connection wrapper * Add handle type hint * Make server_side_parameters optional --------- Co-authored-by: Anthony LaRocca Co-authored-by: Mike Alfare <13974384+mikealfare@users.noreply.github.com> --- .../unreleased/Features-20230707-104150.yaml | 6 ++++++ dbt/adapters/spark/connections.py | 5 ++++- dbt/adapters/spark/session.py | 17 +++++++++++++---- 3 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 .changes/unreleased/Features-20230707-104150.yaml diff --git a/.changes/unreleased/Features-20230707-104150.yaml b/.changes/unreleased/Features-20230707-104150.yaml new file mode 100644 index 000000000..183a37b45 --- /dev/null +++ b/.changes/unreleased/Features-20230707-104150.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support server_side_parameters for Spark session connection method +time: 2023-07-07T10:41:50.01541+02:00 +custom: + Author: alarocca-apixio + Issue: "690" diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 2a7f8188d..5d3e99a64 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -350,6 +350,7 @@ def open(cls, connection: Connection) -> Connection: creds = connection.credentials exc = None + handle: Any for i in range(1 + creds.connect_retries): try: @@ -460,7 +461,9 @@ def open(cls, connection: Connection) -> Connection: SessionConnectionWrapper, ) - handle = SessionConnectionWrapper(Connection()) # type: ignore + handle = SessionConnectionWrapper( + Connection(server_side_parameters=creds.server_side_parameters) + ) else: raise dbt.exceptions.DbtProfileError( f"invalid credential method: {creds.method}" diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index 5e4bcc492..0e3717172 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -4,7 +4,7 @@ import datetime as dt from types import TracebackType -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from dbt.events import AdapterLogger from dbt.utils import DECIMALS @@ -24,9 +24,10 @@ class Cursor: https://github.com/mkleehammer/pyodbc/wiki/Cursor """ - def __init__(self) -> None: + def __init__(self, *, server_side_parameters: Optional[Dict[str, Any]] = None) -> None: self._df: Optional[DataFrame] = None self._rows: Optional[List[Row]] = None + self.server_side_parameters = server_side_parameters or {} def __enter__(self) -> Cursor: return self @@ -106,7 +107,12 @@ def execute(self, sql: str, *parameters: Any) -> None: """ if len(parameters) > 0: sql = sql % parameters - spark_session = SparkSession.builder.enableHiveSupport().getOrCreate() + builder = SparkSession.builder.enableHiveSupport() + + for parameter, value in self.server_side_parameters.items(): + builder = builder.config(parameter, value) + + spark_session = builder.getOrCreate() self._df = spark_session.sql(sql) def fetchall(self) -> Optional[List[Row]]: @@ -159,6 +165,9 @@ class Connection: https://github.com/mkleehammer/pyodbc/wiki/Connection """ + def __init__(self, *, server_side_parameters: Optional[Dict[Any, str]] = None) -> None: + self.server_side_parameters = server_side_parameters or {} + def cursor(self) -> Cursor: """ Get a cursor. @@ -168,7 +177,7 @@ def cursor(self) -> Cursor: out : Cursor The cursor. """ - return Cursor() + return Cursor(server_side_parameters=self.server_side_parameters) class SessionConnectionWrapper(object):