Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): improve error message for http timeouts #912

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/prisma/_async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,28 @@

import httpx

from .utils import ExcConverter
from ._types import Method
from .errors import HTTPClientTimeoutError
from .http_abstract import AbstractHTTP, AbstractResponse

__all__ = ('HTTP', 'AsyncHTTP', 'Response', 'client')


convert_exc = ExcConverter(
{
httpx.ConnectTimeout: HTTPClientTimeoutError,
httpx.ReadTimeout: HTTPClientTimeoutError,
httpx.WriteTimeout: HTTPClientTimeoutError,
httpx.PoolTimeout: HTTPClientTimeoutError,
}
)


class AsyncHTTP(AbstractHTTP[httpx.AsyncClient, httpx.Response]):
session: httpx.AsyncClient

@convert_exc
@override
async def download(self, url: str, dest: str) -> None:
async with self.session.stream('GET', url, timeout=None) as resp:
Expand All @@ -21,14 +34,17 @@ async def download(self, url: str, dest: str) -> None:
async for chunk in resp.aiter_bytes():
fd.write(chunk)

@convert_exc
@override
async def request(self, method: Method, url: str, **kwargs: Any) -> 'Response':
return Response(await self.session.request(method, url, **kwargs))

@convert_exc
@override
def open(self) -> None:
self.session = httpx.AsyncClient(**self.session_kwargs)

@convert_exc
@override
async def close(self) -> None:
if self.should_close():
Expand Down
11 changes: 10 additions & 1 deletion src/prisma/_constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Dict
from typing import Any, Dict
from datetime import timedelta

import httpx

DEFAULT_CONNECT_TIMEOUT: timedelta = timedelta(seconds=10)
DEFAULT_TX_MAX_WAIT: timedelta = timedelta(milliseconds=2000)
DEFAULT_TX_TIMEOUT: timedelta = timedelta(milliseconds=5000)

DEFAULT_HTTP_LIMITS: httpx.Limits = httpx.Limits(max_connections=1000)
DEFAULT_HTTP_TIMEOUT: httpx.Timeout = httpx.Timeout(30)
DEFAULT_HTTP_CONFIG: Dict[str, Any] = {
'limits': DEFAULT_HTTP_LIMITS,
'timeout': DEFAULT_HTTP_TIMEOUT,
}

# key aliases to transform query arguments to make them more pythonic
QUERY_BUILDER_ALIASES: Dict[str, str] = {
'startswith': 'startsWith',
Expand Down
16 changes: 16 additions & 0 deletions src/prisma/_sync_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,28 @@

import httpx

from .utils import ExcConverter
from ._types import Method
from .errors import HTTPClientTimeoutError
from .http_abstract import AbstractHTTP, AbstractResponse

__all__ = ('HTTP', 'SyncHTTP', 'Response', 'client')


convert_exc = ExcConverter(
{
httpx.ConnectTimeout: HTTPClientTimeoutError,
httpx.ReadTimeout: HTTPClientTimeoutError,
httpx.WriteTimeout: HTTPClientTimeoutError,
httpx.PoolTimeout: HTTPClientTimeoutError,
}
)


class SyncHTTP(AbstractHTTP[httpx.Client, httpx.Response]):
session: httpx.Client

@convert_exc
@override
def download(self, url: str, dest: str) -> None:
with self.session.stream('GET', url, timeout=None) as resp:
Expand All @@ -20,14 +33,17 @@ def download(self, url: str, dest: str) -> None:
for chunk in resp.iter_bytes():
fd.write(chunk)

@convert_exc
@override
def request(self, method: Method, url: str, **kwargs: Any) -> 'Response':
return Response(self.session.request(method, url, **kwargs))

@convert_exc
@override
def open(self) -> None:
self.session = httpx.Client(**self.session_kwargs)

@convert_exc
@override
def close(self) -> None:
if self.should_close():
Expand Down
2 changes: 2 additions & 0 deletions src/prisma/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
FuncType = Callable[..., object]
CoroType = Callable[..., Coroutine[Any, Any, object]]

ExcMapping = Mapping[Type[BaseException], Type[BaseException]]


@runtime_checkable
class InheritsGeneric(Protocol):
Expand Down
9 changes: 9 additions & 0 deletions src/prisma/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'TableNotFoundError',
'RecordNotFoundError',
'HTTPClientClosedError',
'HTTPClientTimeoutError',
'ClientNotConnectedError',
'PrismaWarning',
'UnsupportedSubclassWarning',
Expand Down Expand Up @@ -44,6 +45,14 @@ def __init__(self) -> None:
super().__init__('Cannot make a request from a closed client.')


class HTTPClientTimeoutError(PrismaError):
def __init__(self) -> None:
super().__init__(
'HTTP operation has timed out.\n'
'The default timeout is 30 seconds. Maybe you should increase it: prisma.Prisma(http_config={"timeout": httpx.Timeout(30)})'
)


class UnsupportedDatabaseError(PrismaError):
context: str
database: str
Expand Down
10 changes: 3 additions & 7 deletions src/prisma/http_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,18 @@
)
from typing_extensions import override

from httpx import Limits, Headers, Timeout
from httpx import Headers

from .utils import _NoneType
from ._types import Method
from .errors import HTTPClientClosedError
from ._constants import DEFAULT_HTTP_CONFIG

Session = TypeVar('Session')
Response = TypeVar('Response')
ReturnType = TypeVar('ReturnType')
MaybeCoroutine = Union[Coroutine[Any, Any, ReturnType], ReturnType]

DEFAULT_CONFIG: Dict[str, Any] = {
'limits': Limits(max_connections=1000),
'timeout': Timeout(30),
}


class AbstractHTTP(ABC, Generic[Session, Response]):
session_kwargs: Dict[str, Any]
Expand All @@ -45,7 +41,7 @@ def __init__(self, **kwargs: Any) -> None:
# Session = open
self._session: Optional[Union[Session, Type[_NoneType]]] = _NoneType
self.session_kwargs = {
**DEFAULT_CONFIG,
**DEFAULT_HTTP_CONFIG,
**kwargs,
}

Expand Down
59 changes: 57 additions & 2 deletions src/prisma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import inspect
import logging
import warnings
import functools
import contextlib
from typing import TYPE_CHECKING, Any, Dict, Union, TypeVar, Iterator, NoReturn, Coroutine
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, Type, Union, TypeVar, Callable, Iterator, NoReturn, Optional, Coroutine
from importlib.util import find_spec

from ._types import CoroType, FuncType, TypeGuard
from ._types import CoroType, FuncType, TypeGuard, ExcMapping

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -139,3 +141,56 @@ def make_optional(value: _T) -> _T | None:

def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)


# TODO: improve typing
class MaybeAsyncContextDecorator(contextlib.ContextDecorator):
"""`ContextDecorator` compatible with sync/async functions."""

def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: # type: ignore[override]
@functools.wraps(func)
async def async_inner(*args: Any, **kwargs: Any) -> object:
async with self._recreate_cm(): # type: ignore[attr-defined]
return await func(*args, **kwargs)

@functools.wraps(func)
def sync_inner(*args: Any, **kwargs: Any) -> object:
with self._recreate_cm(): # type: ignore[attr-defined]
return func(*args, **kwargs)

if is_coroutine(func):
return async_inner
else:
return sync_inner


class ExcConverter(MaybeAsyncContextDecorator):
"""`MaybeAsyncContextDecorator` to convert exceptions."""

def __init__(self, exc_mapping: ExcMapping) -> None:
self._exc_mapping = exc_mapping

def __enter__(self) -> 'ExcConverter':
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc is not None and exc_type is not None:
target_exc_type = self._exc_mapping.get(exc_type)
if target_exc_type is not None:
raise target_exc_type() from exc

async def __aenter__(self) -> 'ExcConverter':
return self.__enter__()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.__exit__(exc_type, exc, exc_tb)
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from prisma import ENGINE_TYPE, SCHEMA_PATH, Prisma, errors, get_client
from prisma.types import HttpConfig
from prisma.testing import reset_client
from prisma._constants import DEFAULT_HTTP_CONFIG
from prisma.cli.prisma import run
from prisma.engine.http import HTTPEngine
from prisma.engine.errors import AlreadyConnectedError
from prisma.http_abstract import DEFAULT_CONFIG

from .utils import Testdir, patch_method

Expand Down Expand Up @@ -140,7 +140,7 @@ async def _test(config: HttpConfig) -> None:

captured = getter()
assert captured is not None
assert captured == ((), {**DEFAULT_CONFIG, **config})
assert captured == ((), {**DEFAULT_HTTP_CONFIG, **config})

await _test({'timeout': 1})
await _test({'timeout': httpx.Timeout(5, connect=10, read=30)})
Expand Down
26 changes: 25 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import httpx
import pytest
from pytest_mock import MockerFixture

from prisma.http import HTTP
from prisma.utils import _NoneType
from prisma._types import Literal
from prisma.errors import HTTPClientClosedError
from prisma.errors import HTTPClientClosedError, HTTPClientTimeoutError

from .utils import patch_method

Expand Down Expand Up @@ -81,3 +82,26 @@ async def test_httpx_default_config(monkeypatch: 'MonkeyPatch') -> None:
'timeout': httpx.Timeout(30),
},
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
'httpx_error',
[
httpx.ConnectTimeout(''),
httpx.ReadTimeout(''),
httpx.WriteTimeout(''),
httpx.PoolTimeout(''),
],
)
async def test_http_timeout_error(httpx_error: BaseException, mocker: MockerFixture) -> None:
"""Ensure that `httpx.TimeoutException` is converted to `prisma.errors.HTTPClientTimeoutError`."""
mocker.patch('httpx.AsyncClient.request', side_effect=httpx_error)

http = HTTP()
http.open()

with pytest.raises(HTTPClientTimeoutError) as exc_info:
await http.request('GET', '/')

assert exc_info.value.__cause__ == httpx_error
66 changes: 66 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import asyncio
from typing import Type, NoReturn

import pytest

from prisma.utils import ExcConverter


@pytest.mark.asyncio
@pytest.mark.parametrize(
('convert_exc', 'raised_exc_type', 'expected_exc_type', 'should_be_converted'),
[
pytest.param(ExcConverter({ValueError: ImportError}), ValueError, ImportError, True, id='should convert'),
pytest.param(
ExcConverter({ValueError: ImportError}), RuntimeError, RuntimeError, False, id='should not convert'
),
],
)
async def test_exc_converter(
convert_exc: ExcConverter,
raised_exc_type: Type[BaseException],
expected_exc_type: Type[BaseException],
should_be_converted: bool,
) -> None:
"""Ensure that `prisma.utils.ExcConverter` works as expected."""

# Test sync context manager
with pytest.raises(expected_exc_type) as exc_info_1:
with convert_exc:
raise raised_exc_type()

# Test async context manager
with pytest.raises(expected_exc_type) as exc_info_2:
async with convert_exc:
await asyncio.sleep(0.1)
raise raised_exc_type()

# Test sync decorator
with pytest.raises(expected_exc_type) as exc_info_3:

@convert_exc
def help_func() -> NoReturn:
raise raised_exc_type()

help_func()

# Test async decorator
with pytest.raises(expected_exc_type) as exc_info_4:

@convert_exc
async def help_func() -> NoReturn:
await asyncio.sleep(0.1)
raise raised_exc_type()

await help_func()

# Test exception cause
if should_be_converted:
assert all(
(
type(exc_info_1.value.__cause__) is raised_exc_type,
type(exc_info_2.value.__cause__) is raised_exc_type,
type(exc_info_3.value.__cause__) is raised_exc_type,
type(exc_info_4.value.__cause__) is raised_exc_type,
)
)
Loading