From 0c749bf70f30d2add02dff6cf864a88101901aab Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Thu, 21 Nov 2024 14:36:30 +0800 Subject: [PATCH] Support spooled protocol --- README.md | 24 ++ setup.py | 2 + tests/integration/test_dbapi_integration.py | 29 +- tests/integration/test_types_integration.py | 12 +- tests/unit/test_client.py | 5 +- trino/client.py | 349 +++++++++++++++++++- trino/constants.py | 1 + trino/dbapi.py | 13 + 8 files changed, 420 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 3c911e4e..6f0eb90a 100644 --- a/README.md +++ b/README.md @@ -469,6 +469,30 @@ conn = connect( ) ``` +## Spooled protocol + +The client spooling protocol requires [a Trino server with spooling protocol support](https://trino.io/docs/current/client/client-protocol.html#spooling-protocol). + +Enable the spooling protocol by specifying a supported encoding in the `encoding` parameter: + +```python +from trino.dbapi import connect + +conn = connect( + encoding="json+zstd" +) +``` + +or a list of supported encodings: + +```python +from trino.dbapi import connect + +conn = connect( + encoding=["json+zstd", "json"] +) +``` + ## Transactions The client runs by default in *autocommit* mode. 
To enable transactions, set diff --git a/setup.py b/setup.py index b8b83b1d..e497ab36 100755 --- a/setup.py +++ b/setup.py @@ -83,11 +83,13 @@ ], python_requires=">=3.9", install_requires=[ + "lz4", "python-dateutil", "pytz", # requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q "requests>=2.31.0", "tzlocal", + "zstandard", ], extras_require={ "all": all_require, diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 01921474..fbd60b9d 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -38,12 +38,13 @@ from trino.transaction import IsolationLevel -@pytest.fixture -def trino_connection(run_trino): +@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"]) +def trino_connection(request, run_trino): host, port = run_trino + encoding = request.param yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 + host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding ) @@ -1831,8 +1832,8 @@ def test_prepared_statement_capability_autodetection(legacy_prepared_statements, @pytest.mark.skipif( - trino_version() <= '464', - reason="spooled protocol was introduced in version 464" + trino_version() <= 466, + reason="spooling protocol was introduced in version 466" ) def test_select_query_spooled_segments(trino_connection): cur = trino_connection.cursor() @@ -1842,8 +1843,22 @@ def test_select_query_spooled_segments(trino_connection): stop => 5, step => 1)) n""") rows = cur.fetchall() - # TODO: improve test - assert len(rows) > 0 + assert len(rows) == 300875 + for row in rows: + assert isinstance(row[0], int), f"Expected integer for orderkey, got {type(row[0])}" + assert isinstance(row[1], int), f"Expected integer for partkey, got {type(row[1])}" + assert isinstance(row[2], int), f"Expected integer for suppkey, got {type(row[2])}" + assert isinstance(row[3], int), f"Expected int for 
linenumber, got {type(row[3])}" + assert isinstance(row[4], float), f"Expected float for quantity, got {type(row[4])}" + assert isinstance(row[5], float), f"Expected float for extendedprice, got {type(row[5])}" + assert isinstance(row[6], float), f"Expected float for discount, got {type(row[6])}" + assert isinstance(row[7], float), f"Expected string for tax, got {type(row[7])}" + assert isinstance(row[8], str), f"Expected string for returnflag, got {type(row[8])}" + assert isinstance(row[9], str), f"Expected string for linestatus, got {type(row[9])}" + assert isinstance(row[10], date), f"Expected date for shipdate, got {type(row[10])}" + assert isinstance(row[11], date), f"Expected date for commitdate, got {type(row[11])}" + assert isinstance(row[12], date), f"Expected date for receiptdate, got {type(row[12])}" + assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}" def get_cursor(legacy_prepared_statements, run_trino): diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 4e595c78..cc927883 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -17,12 +17,18 @@ from tests.integration.conftest import trino_version -@pytest.fixture -def trino_connection(run_trino): +@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"]) +def trino_connection(request, run_trino): host, port = run_trino + encoding = request.param yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 + host=host, + port=port, + user="test", + source="test", + max_attempts=1, + encoding=encoding ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 653423a0..b33b72f5 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -97,6 +97,7 @@ def test_request_headers(mock_get_and_post): accept_encoding_value = "identity,deflate,gzip" client_info_header = 
constants.HEADER_CLIENT_INFO client_info_value = "some_client_info" + encoding = "json+zstd" with pytest.deprecated_call(): req = TrinoRequest( @@ -109,6 +110,7 @@ def test_request_headers(mock_get_and_post): catalog=catalog, schema=schema, timezone=timezone, + encoding=encoding, headers={ accept_encoding_header: accept_encoding_value, client_info_header: client_info_value, @@ -143,7 +145,8 @@ def assert_headers(headers): "catalog2=" + urllib.parse.quote("ROLE{catalog2_role}") ) assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}" - assert len(headers.keys()) == 13 + assert headers[constants.HEADER_ENCODING] == encoding + assert len(headers.keys()) == 14 req.post("URL") _, post_kwargs = post.call_args diff --git a/trino/client.py b/trino/client.py index da5e4047..298ece2b 100644 --- a/trino/client.py +++ b/trino/client.py @@ -34,27 +34,37 @@ """ from __future__ import annotations +import abc +import base64 import copy import functools +import json import os import random import re import threading import urllib.parse import warnings +from abc import abstractmethod +from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime from email.utils import parsedate_to_datetime from time import sleep from typing import Any +from typing import cast from typing import Dict from typing import List +from typing import Literal from typing import Optional from typing import Tuple +from typing import TypedDict from typing import Union from zoneinfo import ZoneInfo +import lz4.block import requests +import zstandard from tzlocal import get_localzone_name # type: ignore import trino.logging @@ -64,7 +74,16 @@ from trino.mapper import RowMapper from trino.mapper import RowMapperFactory -__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"] +__all__ = [ + "ClientSession", + "TrinoQuery", + "TrinoRequest", + "PROXIES", + "SpooledData", + "SpooledSegment", + "InlineSegment", + "Segment" +] logger = 
trino.logging.get_logger(__name__) @@ -114,6 +133,7 @@ class ClientSession: :param roles: roles for the current session. Some connectors do not support role management. See connector documentation for more details. :param timezone: The timezone for query processing. Defaults to the system's local timezone. + :param encoding: The encoding for the spooling protocol. Defaults to None. """ def __init__( @@ -130,6 +150,7 @@ def __init__( client_tags: Optional[List[str]] = None, roles: Optional[Union[Dict[str, str], str]] = None, timezone: Optional[str] = None, + encoding: Optional[Union[str, List[str]]] = None, ): self._object_lock = threading.Lock() self._prepared_statements: Dict[str, str] = {} @@ -148,6 +169,7 @@ def __init__( self._timezone = timezone or get_localzone_name() if timezone: # Check timezone validity ZoneInfo(timezone) + self._encoding = encoding @property def user(self) -> str: @@ -243,6 +265,11 @@ def timezone(self) -> str: with self._object_lock: return self._timezone + @property + def encoding(self): + with self._object_lock: + return self._encoding + @staticmethod def _format_roles(roles: Union[Dict[str, str], str]) -> Dict[str, str]: if isinstance(roles, str): @@ -308,7 +335,7 @@ class TrinoStatus: next_uri: Optional[str] update_type: Optional[str] update_count: Optional[int] - rows: List[Any] + rows: Union[List[Any], Dict[str, Any]] columns: List[Any] def __repr__(self): @@ -471,6 +498,14 @@ def http_headers(self) -> Dict[str, str]: headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone + if self._client_session.encoding is None: + pass + elif isinstance(self._client_session.encoding, list): + headers[constants.HEADER_ENCODING] = ",".join(self._client_session.encoding) + elif isinstance(self._client_session.encoding, str): + headers[constants.HEADER_ENCODING] = self._client_session.encoding 
def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
    """Convert a spooling-protocol ``data`` payload into a SpooledData object.

    :param rows: the decoded JSON object holding the ``encoding`` of the
        payload and the list of raw ``segments``.
    :returns: a SpooledData wrapping one InlineSegment or SpooledSegment per
        raw segment, in the order they were received.
    :raises ValueError: if a segment has a ``type`` other than ``inline`` or
        ``spooled``.
    """
    payload_encoding = rows["encoding"]
    wrapped: List[Segment] = []
    for raw_segment in rows["segments"]:
        kind = raw_segment["type"]
        if kind == "inline":
            wrapped.append(InlineSegment(cast(_InlineSegmentTO, raw_segment)))
            continue
        if kind == "spooled":
            # Spooled segments need the request object to download their data later on.
            wrapped.append(SpooledSegment(cast(_SpooledSegmentTO, raw_segment), self._request))
            continue
        raise ValueError(f"Unsupported segment type: {kind}")

    return SpooledData(payload_encoding, wrapped)
Literal["json", "json+std", "json+lz4"] + segments: List[_SegmentTO] + + +class _SegmentMetadataTO(TypedDict): + uncompressedSize: str + segmentSize: str + + +class _SegmentTO(_SegmentMetadataTO): + type: Literal["spooled", "inline"] + metadata: _SegmentMetadataTO + + +class _SpooledSegmentTO(_SegmentTO): + uri: str + ackUri: str + headers: Dict[str, List[str]] + + +class _InlineSegmentTO(_SegmentTO): + data: str + + +class Segment(abc.ABC): + """ + Abstract base class representing a segment of data produced by the spooling protocol. + + Attributes: + metadata (property): Metadata associated with the segment. + rows (property): Returns the decoded and mapped rows of data. + """ + def __init__(self, segment: _SegmentTO) -> None: + self._segment = segment + + @property + @abstractmethod + def data(self): + pass + + @property + def metadata(self) -> _SegmentMetadataTO: + return self._segment["metadata"] + + +class InlineSegment(Segment): + """ + A subclass of Segment that handles inline data segments. The data is base64 encoded and + requires mapping to rows using the provided row_mapper. + + Attributes: + rows (property): The rows of data in the segment, decoded and mapped from the base64 encoded data. + """ + def __init__(self, segment: _InlineSegmentTO) -> None: + super().__init__(segment) + self._segment = cast(_InlineSegmentTO, segment) + + @property + def data(self) -> bytes: + return base64.b64decode(self._segment["data"]) + + def __repr__(self): + return f"InlineSegment(metadata={self.metadata})" + + +class SpooledSegment(Segment): + """ + A subclass of Segment that handles spooled data segments, where data may be compressed and needs to be + retrieved via HTTP requests. The segment includes methods for acknowledging processing and loading the + segment from remote storage. + + Attributes: + rows (property): The rows of data, loaded and mapped from the spooled segment. + uri (property): The URI for the spooled segment. 
+ ack_uri (property): The URI for acknowledging the processing of the spooled segment. + headers (property): The headers associated with the spooled segment. + + Methods: + acknowledge(): Sends an acknowledgment request for the segment. + """ + def __init__( + self, + segment: _SpooledSegmentTO, + request: TrinoRequest, + ) -> None: + super().__init__(segment) + self._segment = cast(_SpooledSegmentTO, segment) + self._request = request + + @property + def data(self) -> bytes: + http_response = self._send_spooling_request(self.uri) + if not http_response.ok: + self._request.raise_response_error(http_response) + return http_response.content + + @property + def uri(self) -> str: + return self._segment["uri"] + + @property + def ack_uri(self) -> str: + return self._segment["ackUri"] + + @property + def headers(self) -> Dict[str, List[str]]: + return self._segment.get("headers", {}) + + def acknowledge(self) -> None: + def acknowledge_request(): + try: + http_response = self._send_spooling_request(self.ack_uri, timeout=2) + if not http_response.ok: + self._request.raise_response_error(http_response) + except Exception as e: + logger.error(f"Failed to acknowledge spooling request for segment {self}: {e}") + # Start the acknowledgment in a background thread + thread = threading.Thread(target=acknowledge_request, daemon=True) + thread.start() + + def _send_spooling_request(self, uri: str, **kwargs) -> requests.Response: + headers_with_single_value = {} + for key, values in self.headers.items(): + if len(values) > 1: + raise ValueError(f"Header '{key}' contains multiple values: {values}") + headers_with_single_value[key] = values[0] + return self._request._get(uri, headers=headers_with_single_value, **kwargs) + + def __repr__(self): + return ( + f"SpooledSegment(metadata={self.metadata})" + ) + + +class SpooledData: + """ + Represents a collection of spooled segments of data, with an encoding format. + + Attributes: + encoding (str): The encoding format of the spooled data. 
class SegmentIterator:
    """
    Iterates over the rows of all segments of a SpooledData, one segment at a time.

    A segment is only decoded when the rows of the previous one are exhausted.
    Each SpooledSegment is acknowledged exactly once, right after its rows have
    been consumed.
    """
    def __init__(self, spooled_data: SpooledData, mapper: RowMapper) -> None:
        self._segments = iter(spooled_data.segments)
        self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(mapper).create(spooled_data.encoding))
        self._rows: Iterator[List[Any]] = iter([])
        self._finished = False
        self._current_segment: Optional[Segment] = None

    def __iter__(self) -> Iterator[List[Any]]:
        return self

    def __next__(self) -> List[Any]:
        # If rows are exhausted, fetch the next segment
        while True:
            try:
                return next(self._rows)
            except StopIteration:
                self._acknowledge_current_segment()
                if self._finished:
                    raise
                self._load_next_segment()

    def _acknowledge_current_segment(self) -> None:
        """Acknowledge the segment that was just consumed, at most once."""
        # Clear the reference first: this path is reached again once the segments
        # are exhausted (and on every later next() call), which used to send
        # duplicate acknowledgments for the last spooled segment.
        segment, self._current_segment = self._current_segment, None
        if isinstance(segment, SpooledSegment):
            segment.acknowledge()

    def _load_next_segment(self) -> None:
        """Decode the next segment, or mark the iterator as finished when none is left."""
        try:
            segment = next(self._segments)
        except StopIteration:
            self._finished = True
            return
        # Decoding happens outside of the try block so that a StopIteration leaking
        # from a decoder is not mistaken for the end of the segments.
        self._current_segment = segment
        self._rows = iter(self._decoder.decode(segment))
class QueryDataDecoder(abc.ABC):
    """Interface for turning the raw bytes of a segment into mapped rows."""
    @abstractmethod
    def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
        pass


class JsonQueryDataDecoder(QueryDataDecoder):
    """Decodes UTF-8 encoded JSON and maps the resulting rows with the given row mapper."""
    def __init__(self, mapper: RowMapper) -> None:
        self._mapper = mapper

    def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
        return self._mapper.map(json.loads(data.decode("utf8")))


class CompressedQueryDataDecoder(QueryDataDecoder):
    """
    Base class for decoders that decompress the data before handing it over to a delegate decoder.

    Subclasses only implement ``decompress``. ``decode`` validates the sizes announced in the
    segment metadata before and after decompression.
    """
    def __init__(self, delegate: QueryDataDecoder) -> None:
        self._delegate = delegate

    @abstractmethod
    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        pass

    def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
        """Decompress ``data`` when the metadata marks it as compressed and decode the result.

        :raises RuntimeError: if the received or the decompressed size does not match
            the size announced in ``metadata``.
        """
        if "uncompressedSize" not in metadata:
            # Data not compressed - below threshold
            return self._delegate.decode(data, metadata)

        # Data is compressed. The sizes are coerced to int because the transfer object
        # declares them as strings; comparing len() with a str would never match.
        expected_compressed_size = int(metadata["segmentSize"])
        if len(data) != expected_compressed_size:
            raise RuntimeError(f"Expected to read {expected_compressed_size} bytes but got {len(data)}")
        decompressed_data = self.decompress(data, metadata)
        expected_uncompressed_size = int(metadata["uncompressedSize"])
        if len(decompressed_data) != expected_uncompressed_size:
            raise RuntimeError(
                "Decompressed size does not match expected segment size, "
                f"expected {expected_uncompressed_size}, got {len(decompressed_data)}"
            )
        return self._delegate.decode(decompressed_data, metadata)


class ZStdQueryDataDecoder(CompressedQueryDataDecoder):
    """Decompresses zstd compressed segments."""
    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        zstd_decompressor = zstandard.ZstdDecompressor()
        return zstd_decompressor.decompress(data)


class Lz4QueryDataDecoder(CompressedQueryDataDecoder):
    """Decompresses lz4 block compressed segments."""
    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        # The uncompressed size is passed explicitly from the segment metadata.
        expected_uncompressed_size = metadata["uncompressedSize"]
        decoded_bytes = lz4.block.decompress(data, uncompressed_size=int(expected_uncompressed_size))
        return decoded_bytes
the execution of a SQL @@ -159,10 +163,18 @@ def __init__( legacy_prepared_statements=None, roles=None, timezone=None, + encoding: Union[str, List[str]] = _USE_DEFAULT_ENCODING, ): # Automatically assign http_schema, port based on hostname parsed_host = urlparse(host, allow_fragments=False) + if encoding is _USE_DEFAULT_ENCODING: + encoding = [ + "json+zstd", + "json+lz4", + "json", + ] + self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path self.port = port if parsed_host.port is None else parsed_host.port self.user = user @@ -182,6 +194,7 @@ def __init__( client_tags=client_tags, roles=roles, timezone=timezone, + encoding=encoding, ) # mypy cannot follow module import if http_session is None: