Skip to content

Commit

Permalink
Support spooled protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Dec 23, 2024
1 parent e1dabdd commit 0c749bf
Show file tree
Hide file tree
Showing 8 changed files with 420 additions and 15 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,30 @@ conn = connect(
)
```

## Spooling protocol

The client spooling protocol requires [a Trino server with spooling protocol support](https://trino.io/docs/current/client/client-protocol.html#spooling-protocol).

Enable the spooling protocol by specifying a supported encoding in the `encoding` parameter:

```python
from trino.dbapi import connect

conn = connect(
    encoding="json+zstd"
)
```

or a list of supported encodings:

```python
from trino.dbapi import connect

conn = connect(
    encoding=["json+zstd", "json"]
)
```

## Transactions

The client runs by default in *autocommit* mode. To enable transactions, set
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@
],
python_requires=">=3.9",
install_requires=[
"lz4",
"python-dateutil",
"pytz",
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
"requests>=2.31.0",
"tzlocal",
"zstandard",
],
extras_require={
"all": all_require,
Expand Down
29 changes: 22 additions & 7 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@
from trino.transaction import IsolationLevel


@pytest.fixture
def trino_connection(run_trino):
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
def trino_connection(request, run_trino):
host, port = run_trino
encoding = request.param

yield trino.dbapi.Connection(
host=host, port=port, user="test", source="test", max_attempts=1
host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding
)


Expand Down Expand Up @@ -1831,8 +1832,8 @@ def test_prepared_statement_capability_autodetection(legacy_prepared_statements,


@pytest.mark.skipif(
trino_version() <= '464',
reason="spooled protocol was introduced in version 464"
trino_version() <= 466,
reason="spooling protocol was introduced in version 466"
)
def test_select_query_spooled_segments(trino_connection):
cur = trino_connection.cursor()
Expand All @@ -1842,8 +1843,22 @@ def test_select_query_spooled_segments(trino_connection):
stop => 5,
step => 1)) n""")
rows = cur.fetchall()
# TODO: improve test
assert len(rows) > 0
assert len(rows) == 300875
for row in rows:
assert isinstance(row[0], int), f"Expected integer for orderkey, got {type(row[0])}"
assert isinstance(row[1], int), f"Expected integer for partkey, got {type(row[1])}"
assert isinstance(row[2], int), f"Expected integer for suppkey, got {type(row[2])}"
assert isinstance(row[3], int), f"Expected int for linenumber, got {type(row[3])}"
assert isinstance(row[4], float), f"Expected float for quantity, got {type(row[4])}"
assert isinstance(row[5], float), f"Expected float for extendedprice, got {type(row[5])}"
assert isinstance(row[6], float), f"Expected float for discount, got {type(row[6])}"
assert isinstance(row[7], float), f"Expected float for tax, got {type(row[7])}"
assert isinstance(row[8], str), f"Expected string for returnflag, got {type(row[8])}"
assert isinstance(row[9], str), f"Expected string for linestatus, got {type(row[9])}"
assert isinstance(row[10], date), f"Expected date for shipdate, got {type(row[10])}"
assert isinstance(row[11], date), f"Expected date for commitdate, got {type(row[11])}"
assert isinstance(row[12], date), f"Expected date for receiptdate, got {type(row[12])}"
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"


def get_cursor(legacy_prepared_statements, run_trino):
Expand Down
12 changes: 9 additions & 3 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
from tests.integration.conftest import trino_version


@pytest.fixture
def trino_connection(run_trino):
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
def trino_connection(request, run_trino):
host, port = run_trino
encoding = request.param

yield trino.dbapi.Connection(
host=host, port=port, user="test", source="test", max_attempts=1
host=host,
port=port,
user="test",
source="test",
max_attempts=1,
encoding=encoding
)


Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_request_headers(mock_get_and_post):
accept_encoding_value = "identity,deflate,gzip"
client_info_header = constants.HEADER_CLIENT_INFO
client_info_value = "some_client_info"
encoding = "json+zstd"

with pytest.deprecated_call():
req = TrinoRequest(
Expand All @@ -109,6 +110,7 @@ def test_request_headers(mock_get_and_post):
catalog=catalog,
schema=schema,
timezone=timezone,
encoding=encoding,
headers={
accept_encoding_header: accept_encoding_value,
client_info_header: client_info_value,
Expand Down Expand Up @@ -143,7 +145,8 @@ def assert_headers(headers):
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
)
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
assert len(headers.keys()) == 13
assert headers[constants.HEADER_ENCODING] == encoding
assert len(headers.keys()) == 14

req.post("URL")
_, post_kwargs = post.call_args
Expand Down
Loading

0 comments on commit 0c749bf

Please sign in to comment.