Skip to content

Commit 5c51cb8

Browse files
mdesmetebyhr
authored andcommitted
Support USE catalog.schema and USE schema
1 parent f44b792 commit 5c51cb8

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

tests/integration/test_dbapi_integration.py

+51
Original file line numberDiff line numberDiff line change
@@ -983,3 +983,54 @@ def retrieve_client_tags_from_query(run_trino, client_tags):
983983

984984
query_client_tags = query_info['session']['clientTags']
985985
return query_client_tags
986+
987+
988+
@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions")
989+
def test_use_catalog_schema(trino_connection):
990+
cur = trino_connection.cursor()
991+
cur.execute('SELECT current_catalog, current_schema')
992+
result = cur.fetchall()
993+
assert result[0][0] is None
994+
assert result[0][1] is None
995+
996+
cur.execute('USE tpch.tiny')
997+
cur.fetchall()
998+
cur.execute('SELECT current_catalog, current_schema')
999+
result = cur.fetchall()
1000+
assert result[0][0] == 'tpch'
1001+
assert result[0][1] == 'tiny'
1002+
1003+
cur.execute('USE tpcds.sf1')
1004+
cur.fetchall()
1005+
cur.execute('SELECT current_catalog, current_schema')
1006+
result = cur.fetchall()
1007+
assert result[0][0] == 'tpcds'
1008+
assert result[0][1] == 'sf1'
1009+
1010+
1011+
@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions")
1012+
def test_use_catalog(run_trino):
1013+
_, host, port = run_trino
1014+
1015+
trino_connection = trino.dbapi.Connection(
1016+
host=host, port=port, user="test", source="test", catalog="tpch", max_attempts=1
1017+
)
1018+
cur = trino_connection.cursor()
1019+
cur.execute('SELECT current_catalog, current_schema')
1020+
result = cur.fetchall()
1021+
assert result[0][0] == 'tpch'
1022+
assert result[0][1] is None
1023+
1024+
cur.execute('USE tiny')
1025+
cur.fetchall()
1026+
cur.execute('SELECT current_catalog, current_schema')
1027+
result = cur.fetchall()
1028+
assert result[0][0] == 'tpch'
1029+
assert result[0][1] == 'tiny'
1030+
1031+
cur.execute('USE sf1')
1032+
cur.fetchall()
1033+
cur.execute('SELECT current_catalog, current_schema')
1034+
result = cur.fetchall()
1035+
assert result[0][0] == 'tpch'
1036+
assert result[0][1] == 'sf1'

trino/client.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,23 @@ def user(self):
126126

127127
@property
128128
def catalog(self):
129-
return self._catalog
129+
with self._object_lock:
130+
return self._catalog
131+
132+
@catalog.setter
133+
def catalog(self, catalog):
134+
with self._object_lock:
135+
self._catalog = catalog
130136

131137
@property
132138
def schema(self):
133-
return self._schema
139+
with self._object_lock:
140+
return self._schema
141+
142+
@schema.setter
143+
def schema(self, schema):
144+
with self._object_lock:
145+
self._schema = schema
134146

135147
@property
136148
def source(self):
@@ -489,6 +501,12 @@ def process(self, http_response) -> TrinoStatus:
489501
):
490502
self._client_session.properties[key] = value
491503

504+
if constants.HEADER_SET_CATALOG in http_response.headers:
505+
self._client_session.catalog = http_response.headers[constants.HEADER_SET_CATALOG]
506+
507+
if constants.HEADER_SET_SCHEMA in http_response.headers:
508+
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]
509+
492510
self._next_uri = response.get("nextUri")
493511

494512
return TrinoStatus(

trino/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,6 @@
4545
HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement'
4646
HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare'
4747
HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare'
48+
49+
HEADER_SET_SCHEMA = "X-Trino-Set-Schema"
50+
HEADER_SET_CATALOG = "X-Trino-Set-Catalog"

0 commit comments

Comments
 (0)