Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: queryset support for flowruns #1460

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions tableauserverclient/server/endpoint/flow_runs_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, Union

from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api
from tableauserverclient.server.endpoint.exceptions import FlowRunFailedException, FlowRunCancelledException
from tableauserverclient.models import FlowRunItem, PaginationItem
from tableauserverclient.models import FlowRunItem
from tableauserverclient.exponential_backoff import ExponentialBackoffTimer

from tableauserverclient.helpers.logging import logger
Expand All @@ -25,13 +25,15 @@ def baseurl(self) -> str:

# Get all flow runs
@api(version="3.10")
def get(self, req_options: Optional["RequestOptions"] = None) -> tuple[list[FlowRunItem], PaginationItem]:
# QuerysetEndpoint expects a PaginationItem to be returned, but FlowRuns
# does not return a PaginationItem. Suppressing the mypy error because the
# changes to the QuerySet class should permit this to function regardless.
def get(self, req_options: Optional["RequestOptions"] = None) -> list[FlowRunItem]: # type: ignore[override]
logger.info("Querying all flow runs on site")
url = self.baseurl
server_response = self.get_request(url, req_options)
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
all_flow_run_items = FlowRunItem.from_response(server_response.content, self.parent_srv.namespace)
return all_flow_run_items, pagination_item
return all_flow_run_items

# Get 1 flow run by id
@api(version="3.10")
Expand All @@ -46,7 +48,7 @@ def get_by_id(self, flow_run_id: str) -> FlowRunItem:

# Cancel 1 flow run by id
@api(version="3.10")
def cancel(self, flow_run_id: str) -> None:
def cancel(self, flow_run_id: Union[str, FlowRunItem]) -> None:
if not flow_run_id:
error = "Flow ID undefined."
raise ValueError(error)
Expand Down
64 changes: 55 additions & 9 deletions tableauserverclient/server/query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Sized
from collections.abc import Iterable, Iterator, Sized
from itertools import count
from typing import Optional, Protocol, TYPE_CHECKING, TypeVar, overload
from collections.abc import Iterable, Iterator
import sys
from tableauserverclient.config import config
from tableauserverclient.models.pagination_item import PaginationItem
from tableauserverclient.server.endpoint.exceptions import ServerResponseError
from tableauserverclient.server.filter import Filter
from tableauserverclient.server.request_options import RequestOptions
from tableauserverclient.server.sort import Sort
Expand Down Expand Up @@ -35,6 +36,32 @@ def to_camel_case(word: str) -> str:


class QuerySet(Iterable[T], Sized):
"""
QuerySet is a class that allows easy filtering, sorting, and iterating over
many endpoints in TableauServerClient. It is designed to be used in a similar
way to Django QuerySets, but with a more limited feature set.

QuerySet is an iterable, and can be used in for loops, list comprehensions,
and other places where iterables are expected.

QuerySet is also Sized, and can be used in places where the length of the
QuerySet is needed. The length of the QuerySet is the total number of items
available in the QuerySet, not just the number of items that have been
fetched. If the endpoint does not return a total count of items, the length
of the QuerySet will be sys.maxsize. If there is no total count, the
QuerySet will continue to fetch items until there are no more items to
fetch.

QuerySet is not re-entrant. It is not designed to be used in multiple places
at the same time. If you need to use a QuerySet in multiple places, you
should create a new QuerySet for each place you need to use it, convert it
to a list, or create a deep copy of the QuerySet.

QuerySets are also indexable, and can be sliced. If you try to access an
index that has not been fetched, the QuerySet will fetch the page that
contains the item you are looking for.
"""

def __init__(self, model: "QuerysetEndpoint[T]", page_size: Optional[int] = None) -> None:
self.model = model
self.request_options = RequestOptions(pagesize=page_size or config.PAGE_SIZE)
Expand All @@ -50,10 +77,20 @@ def __iter__(self: Self) -> Iterator[T]:
for page in count(1):
self.request_options.pagenumber = page
self._result_cache = []
self._fetch_all()
try:
self._fetch_all()
except ServerResponseError as e:
if e.code == "400006":
# If the endpoint does not support pagination, it will end
# up overrunning the total number of pages. Catch the
# error and break out of the loop.
raise StopIteration
yield from self._result_cache
# Set result_cache to empty so the fetch will populate
if (page * self.page_size) >= len(self):
# If the length of the QuerySet is unknown, continue fetching until
# the result cache is empty.
if (size := len(self)) == 0:
continue
if (page * self.page_size) >= size:
return

@overload
Expand Down Expand Up @@ -114,10 +151,15 @@ def _fetch_all(self: Self) -> None:
Retrieve the data and store result and pagination item in cache
"""
if not self._result_cache:
self._result_cache, self._pagination_item = self.model.get(self.request_options)
response = self.model.get(self.request_options)
if isinstance(response, tuple):
self._result_cache, self._pagination_item = response
else:
self._result_cache = response
self._pagination_item = PaginationItem()

def __len__(self: Self) -> int:
return self.total_available
return self.total_available or sys.maxsize

@property
def total_available(self: Self) -> int:
Expand All @@ -127,12 +169,16 @@ def total_available(self: Self) -> int:
@property
def page_number(self: Self) -> int:
self._fetch_all()
return self._pagination_item.page_number
# If the PaginationItem is not returned from the endpoint, use the
# pagenumber from the RequestOptions.
return self._pagination_item.page_number or self.request_options.pagenumber

@property
def page_size(self: Self) -> int:
self._fetch_all()
return self._pagination_item.page_size
# If the PaginationItem is not returned from the endpoint, use the
# pagesize from the RequestOptions.
return self._pagination_item.page_size or self.request_options.pagesize

def filter(self: Self, *invalid, page_size: Optional[int] = None, **kwargs) -> Self:
if invalid:
Expand Down
14 changes: 14 additions & 0 deletions test/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path
import unittest
from xml.etree import ElementTree as ET
from contextlib import contextmanager

TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets")
Expand All @@ -18,6 +19,19 @@ def read_xml_assets(*args):
return map(read_xml_asset, args)


def server_response_error_factory(code: str, summary: str, detail: str) -> str:
    """Return a serialized ``tsResponse`` error document as text.

    The document has a single ``error`` element carrying ``code`` as an
    attribute, with ``summary`` and ``detail`` child elements holding the
    given strings.
    """
    document = ET.Element("tsResponse")
    error_node = ET.SubElement(document, "error", {"code": code})

    # Children are appended in this order: summary first, then detail.
    for tag, text in (("summary", summary), ("detail", detail)):
        ET.SubElement(error_node, tag).text = text

    serialized = ET.tostring(document, encoding="utf-8")
    return serialized.decode("utf-8")


@contextmanager
def mocked_time():
mock_time = 0
Expand Down
3 changes: 1 addition & 2 deletions test/assets/flow_runs_get.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
<tsResponse xmlns="http://tableau.com/api" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://tableau.com/api http://tableau.com/api/ts-api-3.10.xsd">
<pagination pageNumber="1" pageSize="100" totalAvailable="2"/>
<flowRuns>
<flowRuns id="cc2e652d-4a9b-4476-8c93-b238c45db968"
flowId="587daa37-b84d-4400-a9a2-aa90e0be7837"
Expand All @@ -16,4 +15,4 @@
progress="100"
backgroundJobId="1ad21a9d-2530-4fbf-9064-efd3c736e023"/>
</flowRuns>
</tsResponse>
</tsResponse>
17 changes: 14 additions & 3 deletions test/test_flowruns.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import sys
import unittest

import requests_mock

import tableauserverclient as TSC
from tableauserverclient.datetime_helpers import format_datetime
from tableauserverclient.server.endpoint.exceptions import FlowRunFailedException
from ._utils import read_xml_asset, mocked_time
from ._utils import read_xml_asset, mocked_time, server_response_error_factory

GET_XML = "flow_runs_get.xml"
GET_BY_ID_XML = "flow_runs_get_by_id.xml"
Expand All @@ -28,9 +29,8 @@ def test_get(self) -> None:
response_xml = read_xml_asset(GET_XML)
with requests_mock.mock() as m:
m.get(self.baseurl, text=response_xml)
all_flow_runs, pagination_item = self.server.flow_runs.get()
all_flow_runs = self.server.flow_runs.get()

self.assertEqual(2, pagination_item.total_available)
self.assertEqual("cc2e652d-4a9b-4476-8c93-b238c45db968", all_flow_runs[0].id)
self.assertEqual("2021-02-11T01:42:55Z", format_datetime(all_flow_runs[0].started_at))
self.assertEqual("2021-02-11T01:57:38Z", format_datetime(all_flow_runs[0].completed_at))
Expand Down Expand Up @@ -98,3 +98,14 @@ def test_wait_for_job_timeout(self) -> None:
m.get(f"{self.baseurl}/{flow_run_id}", text=response_xml)
with self.assertRaises(TimeoutError):
self.server.flow_runs.wait_for_job(flow_run_id, timeout=30)

def test_queryset(self) -> None:
    """Check that a QuerySet over flow runs reports ``sys.maxsize`` as its length.

    Page 1 is mocked with the flow-runs asset. Page 2 is mocked with a
    400006 error document, i.e. the response registered for a request that
    goes past the last page.
    """
    response_xml = read_xml_asset(GET_XML)
    error_response = server_response_error_factory(
        "400006", "Bad Request", "0xB4EAB088 : The start index '9900' is greater than or equal to the total count.)"
    )
    with requests_mock.mock() as m:
        m.get(f"{self.baseurl}?pageNumber=1", text=response_xml)
        m.get(f"{self.baseurl}?pageNumber=2", text=error_response)
        queryset = self.server.flow_runs.all()
        # Use the unittest assertion, like the other tests in this TestCase:
        # it reports both values on failure and is not stripped under
        # ``python -O`` the way a bare ``assert`` is.
        self.assertEqual(sys.maxsize, len(queryset))
Loading