Skip to content

Commit

Permalink
feat(client): improve error message for http timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanblade committed Mar 6, 2024
1 parent b690482 commit 0384fbc
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 13 deletions.
16 changes: 16 additions & 0 deletions src/prisma/_async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,28 @@

import httpx

from .utils import ExcConverter
from ._types import Method
from .errors import HTTPClientTimeoutError
from .http_abstract import AbstractHTTP, AbstractResponse

__all__ = ('HTTP', 'AsyncHTTP', 'Response', 'client')


convert_exc = ExcConverter(
{
httpx.ConnectTimeout: HTTPClientTimeoutError,
httpx.ReadTimeout: HTTPClientTimeoutError,
httpx.WriteTimeout: HTTPClientTimeoutError,
httpx.PoolTimeout: HTTPClientTimeoutError,
}
)


class AsyncHTTP(AbstractHTTP[httpx.AsyncClient, httpx.Response]):
session: httpx.AsyncClient

@convert_exc
@override
async def download(self, url: str, dest: str) -> None:
async with self.session.stream('GET', url, timeout=None) as resp:
Expand All @@ -21,14 +34,17 @@ async def download(self, url: str, dest: str) -> None:
async for chunk in resp.aiter_bytes():
fd.write(chunk)

@convert_exc
@override
async def request(self, method: Method, url: str, **kwargs: Any) -> 'Response':
return Response(await self.session.request(method, url, **kwargs))

@convert_exc
@override
def open(self) -> None:
self.session = httpx.AsyncClient(**self.session_kwargs)

@convert_exc
@override
async def close(self) -> None:
if self.should_close():
Expand Down
11 changes: 10 additions & 1 deletion src/prisma/_constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Dict
from typing import Any, Dict
from datetime import timedelta

import httpx

DEFAULT_CONNECT_TIMEOUT: timedelta = timedelta(seconds=10)
DEFAULT_TX_MAX_WAIT: timedelta = timedelta(milliseconds=2000)
DEFAULT_TX_TIMEOUT: timedelta = timedelta(milliseconds=5000)

DEFAULT_HTTP_LIMITS: httpx.Limits = httpx.Limits(max_connections=1000)
DEFAULT_HTTP_TIMEOUT: httpx.Timeout = httpx.Timeout(30)
DEFAULT_HTTP_CONFIG: Dict[str, Any] = {
'limits': DEFAULT_HTTP_LIMITS,
'timeout': DEFAULT_HTTP_TIMEOUT,
}

# key aliases to transform query arguments to make them more pythonic
QUERY_BUILDER_ALIASES: Dict[str, str] = {
'startswith': 'startsWith',
Expand Down
16 changes: 16 additions & 0 deletions src/prisma/_sync_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,28 @@

import httpx

from .utils import ExcConverter
from ._types import Method
from .errors import HTTPClientTimeoutError
from .http_abstract import AbstractHTTP, AbstractResponse

__all__ = ('HTTP', 'SyncHTTP', 'Response', 'client')


convert_exc = ExcConverter(
{
httpx.ConnectTimeout: HTTPClientTimeoutError,
httpx.ReadTimeout: HTTPClientTimeoutError,
httpx.WriteTimeout: HTTPClientTimeoutError,
httpx.PoolTimeout: HTTPClientTimeoutError,
}
)


class SyncHTTP(AbstractHTTP[httpx.Client, httpx.Response]):
session: httpx.Client

@convert_exc
@override
def download(self, url: str, dest: str) -> None:
with self.session.stream('GET', url, timeout=None) as resp:
Expand All @@ -20,14 +33,17 @@ def download(self, url: str, dest: str) -> None:
for chunk in resp.iter_bytes():
fd.write(chunk)

@convert_exc
@override
def request(self, method: Method, url: str, **kwargs: Any) -> 'Response':
return Response(self.session.request(method, url, **kwargs))

@convert_exc
@override
def open(self) -> None:
self.session = httpx.Client(**self.session_kwargs)

@convert_exc
@override
def close(self) -> None:
if self.should_close():
Expand Down
2 changes: 2 additions & 0 deletions src/prisma/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
FuncType = Callable[..., object]
CoroType = Callable[..., Coroutine[Any, Any, object]]

ExcMapping = Mapping[Type[BaseException], Type[BaseException]]


@runtime_checkable
class InheritsGeneric(Protocol):
Expand Down
9 changes: 9 additions & 0 deletions src/prisma/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'TableNotFoundError',
'RecordNotFoundError',
'HTTPClientClosedError',
'HTTPClientTimeoutError',
'ClientNotConnectedError',
'PrismaWarning',
'UnsupportedSubclassWarning',
Expand Down Expand Up @@ -44,6 +45,14 @@ def __init__(self) -> None:
super().__init__('Cannot make a request from a closed client.')


class HTTPClientTimeoutError(PrismaError):
def __init__(self) -> None:
super().__init__(
'HTTP operation has timed out.\n'
'The default timeout is 30 seconds. Maybe you should increase it: prisma.Prisma(http_config={"timeout": httpx.Timeout(30)})'
)


class UnsupportedDatabaseError(PrismaError):
context: str
database: str
Expand Down
10 changes: 3 additions & 7 deletions src/prisma/http_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,18 @@
)
from typing_extensions import override

from httpx import Limits, Headers, Timeout
from httpx import Headers

from .utils import _NoneType
from ._types import Method
from .errors import HTTPClientClosedError
from ._constants import DEFAULT_HTTP_CONFIG

Session = TypeVar('Session')
Response = TypeVar('Response')
ReturnType = TypeVar('ReturnType')
MaybeCoroutine = Union[Coroutine[Any, Any, ReturnType], ReturnType]

DEFAULT_CONFIG: Dict[str, Any] = {
'limits': Limits(max_connections=1000),
'timeout': Timeout(30),
}


class AbstractHTTP(ABC, Generic[Session, Response]):
session_kwargs: Dict[str, Any]
Expand All @@ -45,7 +41,7 @@ def __init__(self, **kwargs: Any) -> None:
# Session = open
self._session: Optional[Union[Session, Type[_NoneType]]] = _NoneType
self.session_kwargs = {
**DEFAULT_CONFIG,
**DEFAULT_HTTP_CONFIG,
**kwargs,
}

Expand Down
59 changes: 57 additions & 2 deletions src/prisma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import inspect
import logging
import warnings
import functools
import contextlib
from typing import TYPE_CHECKING, Any, Dict, Union, TypeVar, Iterator, NoReturn, Coroutine
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, Type, Union, TypeVar, Callable, Iterator, NoReturn, Optional, Coroutine
from importlib.util import find_spec

from ._types import CoroType, FuncType, TypeGuard
from ._types import CoroType, FuncType, TypeGuard, ExcMapping

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -139,3 +141,56 @@ def make_optional(value: _T) -> _T | None:

def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)


# TODO: improve typing
class MaybeAsyncContextDecorator(contextlib.ContextDecorator):
"""`ContextDecorator` compatible with sync/async functions."""

def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: # type: ignore[override]
@functools.wraps(func)
async def async_inner(*args: Any, **kwargs: Any) -> object:
async with self._recreate_cm(): # type: ignore[attr-defined]
return await func(*args, **kwargs)

@functools.wraps(func)
def sync_inner(*args: Any, **kwargs: Any) -> object:
with self._recreate_cm(): # type: ignore[attr-defined]
return func(*args, **kwargs)

if is_coroutine(func):
return async_inner
else:
return sync_inner


class ExcConverter(MaybeAsyncContextDecorator):
"""`MaybeAsyncContextDecorator` to convert exceptions."""

def __init__(self, exc_mapping: ExcMapping) -> None:
self._exc_mapping = exc_mapping

def __enter__(self) -> 'ExcConverter':
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc is not None and exc_type is not None:
target_exc_type = self._exc_mapping.get(exc_type)
if target_exc_type is not None:
raise target_exc_type() from exc

async def __aenter__(self) -> 'ExcConverter':
return self.__enter__()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.__exit__(exc_type, exc, exc_tb)
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from prisma import ENGINE_TYPE, SCHEMA_PATH, Prisma, errors, get_client
from prisma.types import HttpConfig
from prisma.testing import reset_client
from prisma._constants import DEFAULT_HTTP_CONFIG
from prisma.cli.prisma import run
from prisma.engine.http import HTTPEngine
from prisma.engine.errors import AlreadyConnectedError
from prisma.http_abstract import DEFAULT_CONFIG

from .utils import Testdir, patch_method

Expand Down Expand Up @@ -140,7 +140,7 @@ async def _test(config: HttpConfig) -> None:

captured = getter()
assert captured is not None
assert captured == ((), {**DEFAULT_CONFIG, **config})
assert captured == ((), {**DEFAULT_HTTP_CONFIG, **config})

await _test({'timeout': 1})
await _test({'timeout': httpx.Timeout(5, connect=10, read=30)})
Expand Down
26 changes: 25 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import httpx
import pytest
from pytest_mock import MockerFixture

from prisma.http import HTTP
from prisma.utils import _NoneType
from prisma._types import Literal
from prisma.errors import HTTPClientClosedError
from prisma.errors import HTTPClientClosedError, HTTPClientTimeoutError

from .utils import patch_method

Expand Down Expand Up @@ -81,3 +82,26 @@ async def test_httpx_default_config(monkeypatch: 'MonkeyPatch') -> None:
'timeout': httpx.Timeout(30),
},
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
'httpx_error',
[
httpx.ConnectTimeout(''),
httpx.ReadTimeout(''),
httpx.WriteTimeout(''),
httpx.PoolTimeout(''),
],
)
async def test_http_timeout_error(httpx_error: BaseException, mocker: MockerFixture) -> None:
"""Ensure that `httpx.TimeoutException` is converted to `prisma.errors.HTTPClientTimeoutError`."""
mocker.patch('httpx.AsyncClient.request', side_effect=httpx_error)

http = HTTP()
http.open()

with pytest.raises(HTTPClientTimeoutError) as exc_info:
await http.request('GET', '/')

assert exc_info.value.__cause__ == httpx_error
64 changes: 64 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio
from typing import Type, NoReturn

import pytest

from prisma.utils import ExcConverter


@pytest.mark.asyncio
@pytest.mark.parametrize(
('convert_exc', 'raised_exc_type', 'expected_exc_type', 'should_be_converted'),
[
pytest.param(ExcConverter({ValueError: ImportError}), ValueError, ImportError, True, id='convert'),
pytest.param(ExcConverter({ValueError: ImportError}), RuntimeError, RuntimeError, False, id='do not convert'),
],
)
async def test_exc_converter(
convert_exc: ExcConverter,
raised_exc_type: Type[BaseException],
expected_exc_type: Type[BaseException],
should_be_converted: bool,
) -> None:
"""Ensure that `prisma.utils.ExcConverter` works as expected."""

# Test sync context manager
with pytest.raises(expected_exc_type) as exc_info_1:
with convert_exc:
raise raised_exc_type()

# Test async context manager
with pytest.raises(expected_exc_type) as exc_info_2:
async with convert_exc:
await asyncio.sleep(0.1)
raise raised_exc_type()

# Test sync decorator
with pytest.raises(expected_exc_type) as exc_info_3:

@convert_exc
def help_func() -> NoReturn:
raise raised_exc_type()

help_func()

# Test async decorator
with pytest.raises(expected_exc_type) as exc_info_4:

@convert_exc
async def help_func() -> NoReturn:
await asyncio.sleep(0.1)
raise raised_exc_type()

await help_func()

# Test exception cause
if should_be_converted:
assert all(
(
type(exc_info_1.value.__cause__) is raised_exc_type,
type(exc_info_2.value.__cause__) is raised_exc_type,
type(exc_info_3.value.__cause__) is raised_exc_type,
type(exc_info_4.value.__cause__) is raised_exc_type,
)
)

0 comments on commit 0384fbc

Please sign in to comment.