Skip to content

Commit

Permalink
Add server side parameters to session connection method (#823)
Browse files Browse the repository at this point in the history
* Pass existing server_side_parameters to session connection wrapper and use to configure SparkSession.

* Incorporating feedback. Moved server side parameters to Connection and pass to cursor from there.

* Add changie

* Add type hint

* Write out loop

* Add type hint

* Remove server_side_parameters from connection wrapper

* Add handle type hint

* Make server_side_parameters optional

---------

Co-authored-by: Anthony LaRocca <[email protected]>
Co-authored-by: Mike Alfare <[email protected]>
  • Loading branch information
3 people authored Jul 24, 2023
1 parent 98f4276 commit d91bd17
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230707-104150.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support server_side_parameters for Spark session connection method
time: 2023-07-07T10:41:50.01541+02:00
custom:
Author: alarocca-apixio
Issue: "690"
5 changes: 4 additions & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def open(cls, connection: Connection) -> Connection:

creds = connection.credentials
exc = None
handle: Any

for i in range(1 + creds.connect_retries):
try:
Expand Down Expand Up @@ -460,7 +461,9 @@ def open(cls, connection: Connection) -> Connection:
SessionConnectionWrapper,
)

handle = SessionConnectionWrapper(Connection()) # type: ignore
handle = SessionConnectionWrapper(
Connection(server_side_parameters=creds.server_side_parameters)
)
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
Expand Down
17 changes: 13 additions & 4 deletions dbt/adapters/spark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import datetime as dt
from types import TracebackType
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from dbt.events import AdapterLogger
from dbt.utils import DECIMALS
Expand All @@ -24,9 +24,10 @@ class Cursor:
https://github.com/mkleehammer/pyodbc/wiki/Cursor
"""

def __init__(self) -> None:
def __init__(self, *, server_side_parameters: Optional[Dict[str, Any]] = None) -> None:
self._df: Optional[DataFrame] = None
self._rows: Optional[List[Row]] = None
self.server_side_parameters = server_side_parameters or {}

def __enter__(self) -> Cursor:
return self
Expand Down Expand Up @@ -106,7 +107,12 @@ def execute(self, sql: str, *parameters: Any) -> None:
"""
if len(parameters) > 0:
sql = sql % parameters
spark_session = SparkSession.builder.enableHiveSupport().getOrCreate()
builder = SparkSession.builder.enableHiveSupport()

for parameter, value in self.server_side_parameters.items():
builder = builder.config(parameter, value)

spark_session = builder.getOrCreate()
self._df = spark_session.sql(sql)

def fetchall(self) -> Optional[List[Row]]:
Expand Down Expand Up @@ -159,6 +165,9 @@ class Connection:
https://github.com/mkleehammer/pyodbc/wiki/Connection
"""

def __init__(self, *, server_side_parameters: Optional[Dict[Any, str]] = None) -> None:
self.server_side_parameters = server_side_parameters or {}

def cursor(self) -> Cursor:
"""
Get a cursor.
Expand All @@ -168,7 +177,7 @@ def cursor(self) -> Cursor:
out : Cursor
The cursor.
"""
return Cursor()
return Cursor(server_side_parameters=self.server_side_parameters)


class SessionConnectionWrapper(object):
Expand Down

0 comments on commit d91bd17

Please sign in to comment.