Skip to content

Commit acd2a24

Browse files
committed
Support setting session properties on a individual statement
1 parent 7c66e94 commit acd2a24

File tree

5 files changed

+101
-15
lines changed

5 files changed

+101
-15
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,29 @@ conn = trino.dbapi.connect(
359359
)
360360
```
361361

362+
## Session properties
363+
364+
Session properties can be set on the connection
365+
366+
```python
367+
import trino
368+
conn = trino.dbapi.connect(
369+
...,
370+
session_properties={"query_max_run_time": "1d"}
371+
)
372+
```
373+
374+
### Statement properties
375+
376+
It's also possible to set a session property for a specific statement by setting it on the Cursor. This is especially handy in the case of hive partitions.
377+
378+
```python
379+
import trino
380+
conn = trino.dbapi.connect()
381+
cur = conn.cursor(statement_properties={"hive.insert_existing_partitions_behavior": "OVERWRITE"})
382+
cur.execute("INSERT INTO hive_partitioned_table SELECT * from another_table")
383+
```
384+
362385
## Timezone
363386

364387
The time zone for the session can be explicitly set using the IANA time zone

tests/integration/test_dbapi_integration.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,22 @@ def test_rowcount_insert(trino_connection):
13031303
assert cur.rowcount == 1
13041304

13051305

1306+
def test_statement_properties(trino_connection):
1307+
exchange_compression_statement = "SHOW SESSION LIKE 'exchange_compression'"
1308+
cur = trino_connection.cursor()
1309+
cur.execute(exchange_compression_statement)
1310+
result = cur.fetchall()
1311+
assert result[0][1] == "false"
1312+
cur = trino_connection.cursor(statement_properties={"exchange_compression": True})
1313+
cur.execute(exchange_compression_statement)
1314+
result = cur.fetchall()
1315+
assert result[0][1] == "True"
1316+
cur = trino_connection.cursor()
1317+
cur.execute(exchange_compression_statement)
1318+
result = cur.fetchall()
1319+
assert result[0][1] == "false"
1320+
1321+
13061322
def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
13071323
assert cur.description[0][1] == trino_type
13081324
assert cur.description[0][2] is None

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ def json(self):
10111011
result = query.execute(additional_http_headers=additional_headers)
10121012

10131013
# Validate the the post function was called with the right argguments
1014-
mock_post.assert_called_once_with(sql, additional_headers)
1014+
mock_post.assert_called_once_with(sql, additional_headers, None)
10151015

10161016
# Validate the result is an instance of TrinoResult
10171017
assert isinstance(result, TrinoResult)

trino/client.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,9 @@ def transaction_id(self, value):
449449

450450
@property
451451
def http_headers(self) -> Dict[str, str]:
452+
return self._create_headers()
453+
454+
def _create_headers(self, statement_properties: Dict[str, Any] = None):
452455
headers = {}
453456

454457
headers[constants.HEADER_CATALOG] = self._client_session.catalog
@@ -466,10 +469,13 @@ def http_headers(self) -> Dict[str, str]:
466469
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
467470
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
468471

472+
session_properties = copy.deepcopy(self._client_session.properties)
473+
if statement_properties is not None:
474+
session_properties.update(statement_properties)
469475
headers[constants.HEADER_SESSION] = ",".join(
470476
# ``name`` must not contain ``=``
471477
"{}={}".format(name, urllib.parse.quote(str(value)))
472-
for name, value in self._client_session.properties.items()
478+
for name, value in session_properties.items()
473479
)
474480

475481
if len(self._client_session.prepared_statements) != 0:
@@ -503,6 +509,9 @@ def http_headers(self) -> Dict[str, str]:
503509

504510
return headers
505511

512+
def with_statement_properties(self, statement_properties: Optional[Dict[str, Any]]):
513+
return self._create_headers(statement_properties)
514+
506515
@property
507516
def max_attempts(self) -> int:
508517
return self._max_attempts
@@ -543,11 +552,15 @@ def statement_url(self) -> str:
543552
def next_uri(self) -> Optional[str]:
544553
return self._next_uri
545554

546-
def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None):
555+
def post(
556+
self, sql: str,
557+
additional_http_headers: Optional[Dict[str, Any]] = None,
558+
statement_properties: Optional[Dict[str, Any]] = None,
559+
):
547560
data = sql.encode("utf-8")
548561
# Deep copy of the http_headers dict since they may be modified for this
549562
# request by the provided additional_http_headers
550-
http_headers = copy.deepcopy(self.http_headers)
563+
http_headers = copy.deepcopy(self.with_statement_properties(statement_properties))
551564

552565
# Update the request headers with the additional_http_headers
553566
http_headers.update(additional_http_headers or {})
@@ -734,6 +747,7 @@ def __init__(
734747
request: TrinoRequest,
735748
query: str,
736749
legacy_primitive_types: bool = False,
750+
statement_properties: Optional[Dict[str, Any]] = None,
737751
) -> None:
738752
self._query_id: Optional[str] = None
739753
self._stats: Dict[Any, Any] = {}
@@ -749,6 +763,7 @@ def __init__(
749763
self._query = query
750764
self._result: Optional[TrinoResult] = None
751765
self._legacy_primitive_types = legacy_primitive_types
766+
self._statement_properties = statement_properties
752767
self._row_mapper: Optional[RowMapper] = None
753768

754769
@property
@@ -803,7 +818,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
803818
if self.cancelled:
804819
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)
805820

806-
response = self._request.post(self._query, additional_http_headers)
821+
response = self._request.post(
822+
self._query,
823+
additional_http_headers,
824+
self._statement_properties,
825+
)
807826
status = self._request.process(response)
808827
self._info_uri = status.info_uri
809828
self._query_id = status.id

trino/dbapi.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,10 @@ def _create_request(self):
207207
self.request_timeout,
208208
)
209209

210-
def cursor(self, legacy_primitive_types: bool = None):
210+
def cursor(
211+
self, legacy_primitive_types: bool = None,
212+
statement_properties: Optional[Dict[str, Any]] = None,
213+
):
211214
"""Return a new :py:class:`Cursor` object using the connection."""
212215
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
213216
if self.transaction is None:
@@ -220,7 +223,8 @@ def cursor(self, legacy_primitive_types: bool = None):
220223
self,
221224
request,
222225
# if legacy_primitive_types is not explicitly set in Cursor, take from Connection
223-
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
226+
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types,
227+
statement_properties
224228
)
225229

226230

@@ -271,7 +275,13 @@ class Cursor(object):
271275
272276
"""
273277

274-
def __init__(self, connection, request, legacy_primitive_types: bool = False):
278+
def __init__(
279+
self,
280+
connection,
281+
request,
282+
legacy_primitive_types: bool = False,
283+
statement_properties: Optional[Dict[str, Any]] = None
284+
):
275285
if not isinstance(connection, Connection):
276286
raise ValueError(
277287
"connection must be a Connection object: {}".format(type(connection))
@@ -283,6 +293,7 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
283293
self._iterator = None
284294
self._query = None
285295
self._legacy_primitive_types = legacy_primitive_types
296+
self._statement_properties = statement_properties
286297

287298
def __iter__(self):
288299
return self._iterator
@@ -370,8 +381,12 @@ def _prepare_statement(self, statement: str, name: str) -> None:
370381
:param name: name that will be assigned to the prepared statement.
371382
"""
372383
sql = f"PREPARE {name} FROM {statement}"
373-
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
374-
legacy_primitive_types=self._legacy_primitive_types)
384+
query = trino.client.TrinoQuery(
385+
self.connection._create_request(),
386+
query=sql,
387+
legacy_primitive_types=self._legacy_primitive_types,
388+
statement_properties=self._statement_properties,
389+
)
375390
query.execute()
376391

377392
def _execute_prepared_statement(
@@ -380,7 +395,12 @@ def _execute_prepared_statement(
380395
params
381396
):
382397
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
383-
return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types)
398+
return trino.client.TrinoQuery(
399+
self._request,
400+
query=sql,
401+
legacy_primitive_types=self._legacy_primitive_types,
402+
statement_properties=self._statement_properties,
403+
)
384404

385405
def _format_prepared_param(self, param):
386406
"""
@@ -460,8 +480,12 @@ def _format_prepared_param(self, param):
460480

461481
def _deallocate_prepared_statement(self, statement_name: str) -> None:
462482
sql = 'DEALLOCATE PREPARE ' + statement_name
463-
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
464-
legacy_primitive_types=self._legacy_primitive_types)
483+
query = trino.client.TrinoQuery(
484+
self.connection._create_request(),
485+
query=sql,
486+
legacy_primitive_types=self._legacy_primitive_types,
487+
statement_properties=self._statement_properties,
488+
)
465489
query.execute()
466490

467491
def _generate_unique_statement_name(self):
@@ -492,8 +516,12 @@ def execute(self, operation, params=None):
492516
self._deallocate_prepared_statement(statement_name)
493517

494518
else:
495-
self._query = trino.client.TrinoQuery(self._request, query=operation,
496-
legacy_primitive_types=self._legacy_primitive_types)
519+
self._query = trino.client.TrinoQuery(
520+
self._request,
521+
query=operation,
522+
legacy_primitive_types=self._legacy_primitive_types,
523+
statement_properties=self._statement_properties,
524+
)
497525
self._iterator = iter(self._query.execute())
498526
return self
499527

0 commit comments

Comments
 (0)