Skip to content

Commit

Permalink
Support setting session properties on a individual statement
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Feb 28, 2023
1 parent 7c66e94 commit acd2a24
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 15 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,29 @@ conn = trino.dbapi.connect(
)
```

## Session properties

Session properties can be set on the connection

```python
import trino
conn = trino.dbapi.connect(
...,
session_properties={"query_max_run_time": "1d"}
)
```

### Statement properties

It's also possible to set a session property for a specific statement by setting it on the Cursor. This is especially handy in the case of hive partitions.

```python
import trino
conn = trino.dbapi.connect()
cur = conn.cursor(statement_properties={"hive.insert_existing_partitions_behavior": "OVERWRITE"})
cur.execute("INSERT INTO hive_partitioned_table SELECT * from another_table")
```

## Timezone

The time zone for the session can be explicitly set using the IANA time zone
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,22 @@ def test_rowcount_insert(trino_connection):
assert cur.rowcount == 1


def test_statement_properties(trino_connection):
exchange_compression_statement = "SHOW SESSION LIKE 'exchange_compression'"
cur = trino_connection.cursor()
cur.execute(exchange_compression_statement)
result = cur.fetchall()
assert result[0][1] == "false"
cur = trino_connection.cursor(statement_properties={"exchange_compression": True})
cur.execute(exchange_compression_statement)
result = cur.fetchall()
assert result[0][1] == "True"
cur = trino_connection.cursor()
cur.execute(exchange_compression_statement)
result = cur.fetchall()
assert result[0][1] == "false"


def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
assert cur.description[0][1] == trino_type
assert cur.description[0][2] is None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ def json(self):
result = query.execute(additional_http_headers=additional_headers)

# Validate the the post function was called with the right argguments
mock_post.assert_called_once_with(sql, additional_headers)
mock_post.assert_called_once_with(sql, additional_headers, None)

# Validate the result is an instance of TrinoResult
assert isinstance(result, TrinoResult)
Expand Down
27 changes: 23 additions & 4 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ def transaction_id(self, value):

@property
def http_headers(self) -> Dict[str, str]:
return self._create_headers()

def _create_headers(self, statement_properties: Dict[str, Any] = None):
headers = {}

headers[constants.HEADER_CATALOG] = self._client_session.catalog
Expand All @@ -466,10 +469,13 @@ def http_headers(self) -> Dict[str, str]:
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)

session_properties = copy.deepcopy(self._client_session.properties)
if statement_properties is not None:
session_properties.update(statement_properties)
headers[constants.HEADER_SESSION] = ",".join(
# ``name`` must not contain ``=``
"{}={}".format(name, urllib.parse.quote(str(value)))
for name, value in self._client_session.properties.items()
for name, value in session_properties.items()
)

if len(self._client_session.prepared_statements) != 0:
Expand Down Expand Up @@ -503,6 +509,9 @@ def http_headers(self) -> Dict[str, str]:

return headers

def with_statement_properties(self, statement_properties: Optional[Dict[str, Any]]):
return self._create_headers(statement_properties)

@property
def max_attempts(self) -> int:
return self._max_attempts
Expand Down Expand Up @@ -543,11 +552,15 @@ def statement_url(self) -> str:
def next_uri(self) -> Optional[str]:
return self._next_uri

def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None):
def post(
self, sql: str,
additional_http_headers: Optional[Dict[str, Any]] = None,
statement_properties: Optional[Dict[str, Any]] = None,
):
data = sql.encode("utf-8")
# Deep copy of the http_headers dict since they may be modified for this
# request by the provided additional_http_headers
http_headers = copy.deepcopy(self.http_headers)
http_headers = copy.deepcopy(self.with_statement_properties(statement_properties))

# Update the request headers with the additional_http_headers
http_headers.update(additional_http_headers or {})
Expand Down Expand Up @@ -734,6 +747,7 @@ def __init__(
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
statement_properties: Optional[Dict[str, Any]] = None,
) -> None:
self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
Expand All @@ -749,6 +763,7 @@ def __init__(
self._query = query
self._result: Optional[TrinoResult] = None
self._legacy_primitive_types = legacy_primitive_types
self._statement_properties = statement_properties
self._row_mapper: Optional[RowMapper] = None

@property
Expand Down Expand Up @@ -803,7 +818,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
if self.cancelled:
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)

response = self._request.post(self._query, additional_http_headers)
response = self._request.post(
self._query,
additional_http_headers,
self._statement_properties,
)
status = self._request.process(response)
self._info_uri = status.info_uri
self._query_id = status.id
Expand Down
48 changes: 38 additions & 10 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ def _create_request(self):
self.request_timeout,
)

def cursor(self, legacy_primitive_types: bool = None):
def cursor(
self, legacy_primitive_types: bool = None,
statement_properties: Optional[Dict[str, Any]] = None,
):
"""Return a new :py:class:`Cursor` object using the connection."""
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
if self.transaction is None:
Expand All @@ -220,7 +223,8 @@ def cursor(self, legacy_primitive_types: bool = None):
self,
request,
# if legacy_primitive_types is not explicitly set in Cursor, take from Connection
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types,
statement_properties
)


Expand Down Expand Up @@ -271,7 +275,13 @@ class Cursor(object):
"""

def __init__(self, connection, request, legacy_primitive_types: bool = False):
def __init__(
self,
connection,
request,
legacy_primitive_types: bool = False,
statement_properties: Optional[Dict[str, Any]] = None
):
if not isinstance(connection, Connection):
raise ValueError(
"connection must be a Connection object: {}".format(type(connection))
Expand All @@ -283,6 +293,7 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
self._iterator = None
self._query = None
self._legacy_primitive_types = legacy_primitive_types
self._statement_properties = statement_properties

def __iter__(self):
return self._iterator
Expand Down Expand Up @@ -370,8 +381,12 @@ def _prepare_statement(self, statement: str, name: str) -> None:
:param name: name that will be assigned to the prepared statement.
"""
sql = f"PREPARE {name} FROM {statement}"
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query = trino.client.TrinoQuery(
self.connection._create_request(),
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)
query.execute()

def _execute_prepared_statement(
Expand All @@ -380,7 +395,12 @@ def _execute_prepared_statement(
params
):
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types)
return trino.client.TrinoQuery(
self._request,
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)

def _format_prepared_param(self, param):
"""
Expand Down Expand Up @@ -460,8 +480,12 @@ def _format_prepared_param(self, param):

def _deallocate_prepared_statement(self, statement_name: str) -> None:
sql = 'DEALLOCATE PREPARE ' + statement_name
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query = trino.client.TrinoQuery(
self.connection._create_request(),
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)
query.execute()

def _generate_unique_statement_name(self):
Expand Down Expand Up @@ -492,8 +516,12 @@ def execute(self, operation, params=None):
self._deallocate_prepared_statement(statement_name)

else:
self._query = trino.client.TrinoQuery(self._request, query=operation,
legacy_primitive_types=self._legacy_primitive_types)
self._query = trino.client.TrinoQuery(
self._request,
query=operation,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)
self._iterator = iter(self._query.execute())
return self

Expand Down

0 comments on commit acd2a24

Please sign in to comment.