Skip to content

Commit

Permalink
mypy for all
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed Nov 2, 2021
1 parent a2b0b6c commit 66e268b
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 153 deletions.
26 changes: 13 additions & 13 deletions aiormq/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
from pamqp import commands as spec
from pamqp.base import Frame
from pamqp.body import ContentBody
from pamqp.common import FieldTable, FieldValue, FieldArray
from pamqp.commands import Basic, Channel, Exchange, Queue, Tx, Confirm
from pamqp.commands import Basic, Channel, Confirm, Exchange, Queue, Tx
from pamqp.common import FieldArray, FieldTable, FieldValue
from pamqp.constants import REPLY_SUCCESS
from pamqp.header import ContentHeader
from pamqp.heartbeat import Heartbeat
from yarl import URL


ExceptionType = Union[BaseException, Type[BaseException]]


# noinspection PyShadowingNames
class TaskWrapper:
__slots__ = "exception", "task"
Expand All @@ -27,7 +30,7 @@ def __init__(self, task: asyncio.Task):
self.task = task
self.exception = asyncio.CancelledError

def throw(self, exception: BaseException) -> None:
def throw(self, exception: ExceptionType) -> None:
self.exception = exception
self.task.cancel()

Expand Down Expand Up @@ -56,15 +59,15 @@ def __repr__(self) -> str:


class DeliveredMessage(NamedTuple):
delivery: Union[Basic.Deliver, GetResultType]
delivery: Union[spec.Basic.Deliver, spec.Basic.Return, GetResultType]
header: ContentHeader
body: bytes
channel: "AbstractChannel"


ChannelRType = Tuple[int, Channel.OpenOk]

CallbackCoro = Coroutine[Any, None, Any]
CallbackCoro = Coroutine[Any, Any, Any]
ConsumerCallback = Callable[[DeliveredMessage], CallbackCoro]
ReturnCallback = Callable[[], CallbackCoro]

Expand Down Expand Up @@ -128,9 +131,6 @@ class ChannelFrame(NamedTuple):
drain_future: Optional[asyncio.Future] = None


ExceptionType = Union[Exception, Type[Exception]]


class AbstractFutureStore:
futures: Set[Union[asyncio.Future, TaskType]]
loop: asyncio.AbstractEventLoop
Expand All @@ -140,7 +140,7 @@ def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None:
raise NotImplementedError

@abstractmethod
async def reject_all(self, exception: Optional[ExceptionType]) -> None:
def reject_all(self, exception: Optional[ExceptionType]) -> Any:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -226,7 +226,7 @@ async def basic_consume(

@abstractmethod
def basic_ack(
self, delivery_tag: str, multiple: bool = False,
self, delivery_tag: int, multiple: bool = False,
) -> DrainResult:
raise NotImplementedError

Expand All @@ -241,7 +241,7 @@ def basic_nack(

@abstractmethod
def basic_reject(
self, delivery_tag: str, *, requeue: bool = True
self, delivery_tag: int, *, requeue: bool = True
) -> DrainResult:
raise NotImplementedError

Expand Down Expand Up @@ -418,7 +418,7 @@ async def confirm_delivery(


class AbstractConnection(AbstractBase):
FRAME_BUFFER: int = 10
FRAME_BUFFER_SIZE: int = 10
# Interval between sending heartbeats based on the heartbeat(timeout)
HEARTBEAT_INTERVAL_MULTIPLIER: TimeoutType

Expand Down Expand Up @@ -475,7 +475,7 @@ async def channel(
self,
channel_number: int = None,
publisher_confirms: bool = True,
frame_buffer: int = FRAME_BUFFER,
frame_buffer_size: int = FRAME_BUFFER_SIZE,
**kwargs: Any
) -> AbstractChannel:
raise NotImplementedError
Expand Down
43 changes: 26 additions & 17 deletions aiormq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import asyncio
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, Optional, Set, Type, TypeVar, Union
from typing import Any, Callable, Coroutine, Optional, Set, TypeVar, Union
from weakref import WeakSet

from .abc import (
AbstractBase, AbstractFutureStore, CoroutineType, TaskType, TaskWrapper,
AbstractBase, AbstractFutureStore, CoroutineType, ExceptionType, TaskType,
TaskWrapper,
)
from .tools import shield

Expand All @@ -26,23 +27,25 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop
self.parent: Optional[FutureStore] = None

def __on_task_done(self, future):
def remover(*_):
def __on_task_done(
self, future: Union[asyncio.Future, TaskWrapper],
) -> Callable[..., Any]:
def remover(*_: Any) -> None:
nonlocal future
if future in self.futures:
self.futures.remove(future)

return remover

def add(self, future: Union[asyncio.Future, TaskWrapper]):
def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None:
self.futures.add(future)
future.add_done_callback(self.__on_task_done(future))

if self.parent:
self.parent.add(future)

@shield
async def reject_all(self, exception: Exception):
async def reject_all(self, exception: Optional[ExceptionType]) -> None:
tasks = []

while self.futures:
Expand All @@ -55,7 +58,7 @@ async def reject_all(self, exception: Exception):
future.throw(exception or Exception)
tasks.append(future)
elif isinstance(future, asyncio.Future):
future.set_exception(exception)
future.set_exception(exception or Exception)

if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
Expand All @@ -65,7 +68,7 @@ def create_task(self, coro: CoroutineType) -> TaskType:
self.add(task)
return task

def create_future(self, weak: bool = False):
def create_future(self, weak: bool = False) -> asyncio.Future:
future = self.loop.create_future()
self.add(future)
return future
Expand All @@ -92,12 +95,14 @@ def __init__(

self.closing = self._create_closing_future()

def _create_closing_future(self):
def _create_closing_future(self) -> asyncio.Future:
future = self.__future_store.create_future()
future.add_done_callback(lambda x: x.exception())
return future

def _cancel_tasks(self, exc: Union[Exception, Type[Exception]] = None):
def _cancel_tasks(
self, exc: ExceptionType = None,
) -> Coroutine[Any, Any, None]:
return self.__future_store.reject_all(exc)

def _future_store_child(self) -> AbstractFutureStore:
Expand All @@ -110,10 +115,12 @@ def create_future(self) -> asyncio.Future:
return self.__future_store.create_future()

@abc.abstractmethod
async def _on_close(self, exc=None): # pragma: no cover
async def _on_close(
self, exc: Optional[ExceptionType] = None
) -> None: # pragma: no cover
return

async def __closer(self, exc):
async def __closer(self, exc: Optional[ExceptionType]) -> None:
if self.is_closed: # pragma: no cover
return

Expand All @@ -123,24 +130,26 @@ async def __closer(self, exc):
with suppress(Exception):
await self._cancel_tasks(exc)

async def close(self, exc=asyncio.CancelledError()) -> None:
async def close(
self, exc: Optional[ExceptionType] = asyncio.CancelledError
) -> None:
if self.is_closed:
return None

await self.loop.create_task(self.__closer(exc))

def __repr__(self):
def __repr__(self) -> str:
cls_name = self.__class__.__name__
return '<{0}: "{1}" at 0x{2:02x}>'.format(
cls_name, str(self), id(self),
)

@abc.abstractmethod
def __str__(self): # pragma: no cover
def __str__(self) -> str: # pragma: no cover
raise NotImplementedError

@property
def is_closed(self):
def is_closed(self) -> bool:
return self.closing.done()


Expand All @@ -149,7 +158,7 @@ def is_closed(self):

def task(func: TaskFunctionType) -> TaskFunctionType:
@wraps(func)
async def wrap(self: Base, *args, **kwargs) -> Any:
async def wrap(self: Base, *args: Any, **kwargs: Any) -> Any:
return await self.create_task(func(self, *args, **kwargs))

return wrap
Loading

0 comments on commit 66e268b

Please sign in to comment.