Skip to content

Commit 43ed692

Browse files
mdesmethashhar
authored andcommitted
Support spooled protocol
1 parent d9c02a2 commit 43ed692

File tree

8 files changed

+440
-15
lines changed

8 files changed

+440
-15
lines changed

README.md

+26
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,32 @@ conn = connect(
469469
)
470470
```
471471

472+
## Spooled protocol
473+
474+
The client spooling protocol requires [a Trino server with spooling protocol support](https://trino.io/docs/current/client/client-protocol.html#spooling-protocol).
475+
476+
Enable the spooling protocol by specifying a supported encoding in the `encoding` parameter:
477+
478+
Supported encodings are `json`, `json+lz4` and `json+zstd`.
479+
480+
```python
481+
from trino.dbapi import connect
482+
483+
conn = connect(
484+
encoding="json+zstd"
485+
)
486+
```
487+
488+
or a list of supported encodings in order of preference:
489+
490+
```python
491+
from trino.dbapi import connect
492+
493+
conn = connect(
494+
encoding=["json+zstd", "json"]
495+
)
496+
```
497+
472498
## Transactions
473499

474500
The client runs by default in *autocommit* mode. To enable transactions, set

setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@
8383
],
8484
python_requires=">=3.9",
8585
install_requires=[
86+
"lz4",
8687
"python-dateutil",
8788
"pytz",
8889
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
8990
"requests>=2.31.0",
9091
"tzlocal",
92+
"zstandard",
9193
],
9294
extras_require={
9395
"all": all_require,

tests/integration/test_dbapi_integration.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@
3838
from trino.transaction import IsolationLevel
3939

4040

41-
@pytest.fixture
42-
def trino_connection(run_trino):
41+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
42+
def trino_connection(request, run_trino):
4343
host, port = run_trino
44+
encoding = request.param
4445

4546
yield trino.dbapi.Connection(
46-
host=host, port=port, user="test", source="test", max_attempts=1
47+
host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding
4748
)
4849

4950

@@ -1831,8 +1832,8 @@ def test_prepared_statement_capability_autodetection(legacy_prepared_statements,
18311832

18321833

18331834
@pytest.mark.skipif(
1834-
trino_version() <= '464',
1835-
reason="spooled protocol was introduced in version 464"
1835+
trino_version() <= 466,
1836+
reason="spooling protocol was introduced in version 466"
18361837
)
18371838
def test_select_query_spooled_segments(trino_connection):
18381839
cur = trino_connection.cursor()
@@ -1842,8 +1843,22 @@ def test_select_query_spooled_segments(trino_connection):
18421843
stop => 5,
18431844
step => 1)) n""")
18441845
rows = cur.fetchall()
1845-
# TODO: improve test
1846-
assert len(rows) > 0
1846+
assert len(rows) == 300875
1847+
for row in rows:
1848+
assert isinstance(row[0], int), f"Expected integer for orderkey, got {type(row[0])}"
1849+
assert isinstance(row[1], int), f"Expected integer for partkey, got {type(row[1])}"
1850+
assert isinstance(row[2], int), f"Expected integer for suppkey, got {type(row[2])}"
1851+
assert isinstance(row[3], int), f"Expected int for linenumber, got {type(row[3])}"
1852+
assert isinstance(row[4], float), f"Expected float for quantity, got {type(row[4])}"
1853+
assert isinstance(row[5], float), f"Expected float for extendedprice, got {type(row[5])}"
1854+
assert isinstance(row[6], float), f"Expected float for discount, got {type(row[6])}"
1855+
assert isinstance(row[7], float), f"Expected string for tax, got {type(row[7])}"
1856+
assert isinstance(row[8], str), f"Expected string for returnflag, got {type(row[8])}"
1857+
assert isinstance(row[9], str), f"Expected string for linestatus, got {type(row[9])}"
1858+
assert isinstance(row[10], date), f"Expected date for shipdate, got {type(row[10])}"
1859+
assert isinstance(row[11], date), f"Expected date for commitdate, got {type(row[11])}"
1860+
assert isinstance(row[12], date), f"Expected date for receiptdate, got {type(row[12])}"
1861+
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"
18471862

18481863

18491864
def get_cursor(legacy_prepared_statements, run_trino):

tests/integration/test_types_integration.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,18 @@
1717
from tests.integration.conftest import trino_version
1818

1919

20-
@pytest.fixture
21-
def trino_connection(run_trino):
20+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
21+
def trino_connection(request, run_trino):
2222
host, port = run_trino
23+
encoding = request.param
2324

2425
yield trino.dbapi.Connection(
25-
host=host, port=port, user="test", source="test", max_attempts=1
26+
host=host,
27+
port=port,
28+
user="test",
29+
source="test",
30+
max_attempts=1,
31+
encoding=encoding
2632
)
2733

2834

tests/unit/test_client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def test_request_headers(mock_get_and_post):
9797
accept_encoding_value = "identity,deflate,gzip"
9898
client_info_header = constants.HEADER_CLIENT_INFO
9999
client_info_value = "some_client_info"
100+
encoding = "json+zstd"
100101

101102
with pytest.deprecated_call():
102103
req = TrinoRequest(
@@ -109,6 +110,7 @@ def test_request_headers(mock_get_and_post):
109110
catalog=catalog,
110111
schema=schema,
111112
timezone=timezone,
113+
encoding=encoding,
112114
headers={
113115
accept_encoding_header: accept_encoding_value,
114116
client_info_header: client_info_value,
@@ -143,7 +145,8 @@ def assert_headers(headers):
143145
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
144146
)
145147
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
146-
assert len(headers.keys()) == 13
148+
assert headers[constants.HEADER_ENCODING] == encoding
149+
assert len(headers.keys()) == 14
147150

148151
req.post("URL")
149152
_, post_kwargs = post.call_args

0 commit comments

Comments
 (0)