Skip to content

Commit

Permalink
Update docstring and names.
Browse files Browse the repository at this point in the history
  • Loading branch information
YooSunYoung committed Dec 16, 2024
1 parent 2e412c1 commit 1385eb6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 56 deletions.
99 changes: 49 additions & 50 deletions src/appstract/event_driven.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
"""Asynchronous application components."""
"""Event-driven application components."""

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
Expand Down Expand Up @@ -79,7 +80,7 @@ def __call__(self) -> AsyncGenerator[EventMessageProtocol | None, None]: ...


MessageT = TypeVar("MessageT", bound=EventMessageProtocol)
HandlerT = TypeVar("HandlerT", bound=Callable)
HandlerT = TypeVar("HandlerT", Callable, Awaitable)


class MessageRouter(LogMixin):
Expand Down Expand Up @@ -115,26 +116,25 @@ def _handler_wrapper(
def _register(
self,
*,
handler_list: dict[type[MessageT], list[HandlerT]],
handler_map: dict[type[MessageT], list[HandlerT]],
event_tp: type[MessageT],
handler: HandlerT,
):
if event_tp in handler_list:
handler_list[event_tp].append(handler)
else:
handler_list[event_tp] = [handler]
existing_list = handler_map.setdefault(event_tp, [])
if handler not in existing_list:
existing_list.append(handler)

def register_handler(
self,
event_tp: type[MessageT],
handler: Callable[[MessageT], Any] | Callable[[MessageT], Awaitable[Any]],
):
if asyncio.iscoroutinefunction(handler):
handler_list = self.awaitable_handlers
handler_map = self.awaitable_handlers
else:
handler_list = self.handlers
handler_map = self.handlers

self._register(handler_list=handler_list, event_tp=event_tp, handler=handler)
self._register(handler_map=handler_map, event_tp=event_tp, handler=handler)

def _collect_results(self, result: Any) -> list[EventMessageProtocol]:
"""Append or extend ``result`` to ``self.message_pipe``.
Expand All @@ -148,7 +148,8 @@ def _collect_results(self, result: Any) -> list[EventMessageProtocol]:
else:
return []

def sync_route(self, message: EventMessageProtocol) -> None:
def route(self, message: EventMessageProtocol) -> None:
"""Route the message to the appropriate handlers."""
# Synchronous handlers
results = []
for handler in (handlers := self.handlers.get(type(message), [])):
Expand All @@ -164,6 +165,7 @@ def sync_route(self, message: EventMessageProtocol) -> None:
self.message_pipe.put(result)

async def async_route(self, message: EventMessageProtocol) -> None:
"""Asynchronously route the message to the appropriate handlers."""
# Synchronous handlers
results = []
for handler in (handlers := self.handlers.get(type(message), [])):
Expand All @@ -186,32 +188,32 @@ async def async_route(self, message: EventMessageProtocol) -> None:
for result in results:
self.message_pipe.put(result)

async def async_run(self) -> AsyncGenerator[EventMessageProtocol | None, None]:
def run(self) -> Generator[EventMessageProtocol | None, None]:
"""Message router daemon."""
while True:
await asyncio.sleep(0)
if self.message_pipe.empty():
await asyncio.sleep(0.1)
yield
while not self.message_pipe.empty():
await self.async_route(self.message_pipe.get())
self.route(self.message_pipe.get())
yield
yield

def sync_run(self) -> Generator[EventMessageProtocol | None, None]:
async def async_run(self) -> AsyncGenerator[EventMessageProtocol | None, None]:
"""Message router daemon."""
while True:
await asyncio.sleep(0)
if self.message_pipe.empty():
yield
await asyncio.sleep(0.1)
while not self.message_pipe.empty():
self.sync_route(self.message_pipe.get())
yield
await self.async_route(self.message_pipe.get())
yield

async def send_message_async(self, message: EventMessageProtocol) -> None:
def send_message(self, message: EventMessageProtocol) -> None:
self.message_pipe.put(message)
await asyncio.sleep(0)

def send_message(self, message: EventMessageProtocol) -> None:
async def async_send_message(self, message: EventMessageProtocol) -> None:
self.message_pipe.put(message)
await asyncio.sleep(0)


@dataclass
Expand All @@ -221,7 +223,18 @@ class StopEventMessage:
content: Any


class ApplicationMixin:
class ApplicationInterface(ABC, LogMixin):
"""Application interface."""

logger: AppLogger
message_router: MessageRouter

@abstractmethod
def cancel_all_tasks(self) -> None: ...

@abstractmethod
def register_daemon(self, daemon: Callable) -> None: ...

def stop_tasks(self, message: EventMessageProtocol | None = None) -> None:
self.info('Stop running application %s...', self.__class__.__name__)
if message is not None and not isinstance(message, StopEventMessage):
Expand All @@ -236,18 +249,6 @@ def register_handler(
"""Register handlers to the application message router."""
self.message_router.register_handler(event_tp, handler)

def register_daemon(
self,
daemon: DaemonMessageGeneratorProtocol | AsyncDaemonMessageGeneratorProtocol,
) -> None:
"""Register a daemon generator to the application.
Registered daemons will be scheduled in the event loop
as :func:`~Application.run` method is called.
The future of the daemon will be collected in the ``self.tasks`` list.
"""
self.daemons.append(daemon)

@contextmanager
def _handle_keyboard_interrupt(self):
_interrupted_count = 0
Expand All @@ -258,13 +259,13 @@ def _handle_keyboard_interrupt(self):
self.info("Received a keyboard interrupt. Exiting...")
self.info("Press Ctrl+C one more time to kill immediately.")
self.message_router.message_pipe.put_nowait(
AsyncApplication.Stop(content=None)
StopEventMessage(content=None)
)
else:
raise e


class SyncApplication(ApplicationMixin, LogMixin):
class SyncApplication(ApplicationInterface, LogMixin):
"""Synchronous Application class.
Main Responsibilities:
Expand All @@ -278,9 +279,7 @@ def __init__(self, logger: AppLogger, message_router: MessageRouter) -> None:
self.tasks: dict[Callable, Generator] = {}
self.logger = logger
self.message_router = message_router
self.daemons: list[DaemonMessageGeneratorProtocol] = [
self.message_router.sync_run
]
self.daemons: list[DaemonMessageGeneratorProtocol] = [self.message_router.run]
self.register_handler(StopEventMessage, self.stop_tasks)
self._break = False
super().__init__()
Expand All @@ -307,19 +306,19 @@ def _daemon_wrapper(
self, daemon: DaemonMessageGeneratorProtocol
) -> Generator[None, None]:
try:
self.info('Running daemon %s', daemon.__class__.__qualname__)
self.info('Running daemon %s', daemon.__qualname__)
yield
except Exception as e:
self.error(f"Daemon {daemon} failed. Cancelling all other tasks...")
# Break all daemon generator loops.
self._break = True
# Let other daemons/handlers clean up.
self.message_router.sync_route(StopEventMessage(None))
self.message_router.route(StopEventMessage(None))
# Make sure all other async tasks are cancelled.
self.cancel_all_tasks()
raise e
else:
self.info("Daemon %s completed.", daemon.__class__.__qualname__)
self.info("Daemon %s completed.", daemon.__qualname__)

def _create_daemon_tasks(
self,
Expand Down Expand Up @@ -365,18 +364,18 @@ def run(self):
try:
next(task)
except StopIteration: # noqa: PERF203
self.info("Daemon %s completed.", daemon.__class__.__qualname__)
self.info("Removing completed daemon %s", daemon.__qualname__)
self.tasks.pop(daemon)
if not self.tasks:
self.info("All daemons completed.")
self.info("All daemons completed. Exiting...")
break

def run_after_run(self):
"""This method is only available for :class:`~AsyncApplication`."""
raise NotImplementedError("This method is only available for AsyncApplication.")


class AsyncApplication(ApplicationMixin, LogMixin):
class AsyncApplication(ApplicationInterface, LogMixin):
"""Asynchronous Application class.
Main Responsibilities:
Expand Down Expand Up @@ -428,7 +427,7 @@ async def _daemon_wrapper(
daemon: AsyncDaemonMessageGeneratorProtocol | DaemonMessageGeneratorProtocol,
) -> AsyncGenerator[None, None]:
try:
self.info('Running daemon %s', daemon.__class__.__qualname__)
self.info('Running daemon %s', daemon.__qualname__)
yield
except Exception as e:
# Make sure all other async tasks are cancelled.
Expand All @@ -444,7 +443,7 @@ async def _daemon_wrapper(
self.cancel_all_tasks()
raise e
else:
self.info("Daemon %s completed.", daemon.__class__.__qualname__)
self.info("Daemon %s completed.", daemon.__qualname__)

def _create_daemon_coroutines(
self,
Expand All @@ -460,14 +459,14 @@ async def run_daemon(
if isinstance(generator, AsyncGenerator):
async for message in generator:
if message is not None:
await self.message_router.send_message_async(message)
await self.message_router.async_send_message(message)
if self._break:
break
await asyncio.sleep(0)
elif isinstance(generator, Generator):
for message in generator:
if message is not None:
await self.message_router.send_message_async(message)
await self.message_router.async_send_message(message)
if self._break:
break
await asyncio.sleep(0)
Expand Down
16 changes: 10 additions & 6 deletions src/appstract/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,25 @@ class HelloWorldMessage:
class Echo(LogMixin):
logger: AppLogger

async def echo(self, msg: EventMessageProtocol) -> None:
await asyncio.sleep(0.5)
async def async_echo(self, msg: EventMessageProtocol) -> None:
await asyncio.sleep(1)
self.error(msg.content)
await asyncio.sleep(1)

def sync_echo(self, msg: EventMessageProtocol) -> None:
time.sleep(0.5)
time.sleep(1)
self.error(msg.content)
time.sleep(1)


class Narc(LogMixin):
logger: AppLogger

async def shout(self) -> AsyncGenerator[EventMessageProtocol, None]:
async def async_shout(self) -> AsyncGenerator[EventMessageProtocol, None]:
self.info("Going to shout hello world 3 times...")
messages = ("Hello World", "Heelllloo World!", "Heeelllllloooo World!")
for msg in messages:
await asyncio.sleep(1)
self.info(msg)
yield HelloWorldMessage(msg)
await asyncio.sleep(1)
Expand All @@ -97,6 +100,7 @@ def sync_shout(self) -> Generator[EventMessageProtocol, None, None]:
self.info("Going to shout hello world 3 times...")
messages = ("Hello World", "Heelllloo World!", "Heeelllllloooo World!")
for msg in messages:
time.sleep(1)
self.info(msg)
yield HelloWorldMessage(msg)
time.sleep(1)
Expand Down Expand Up @@ -124,10 +128,10 @@ def run_async_helloworld():
narc = factory[Narc]

# Handlers
app.register_handler(HelloWorldMessage, echo.echo)
app.register_handler(HelloWorldMessage, echo.async_echo)

# Daemons
app.register_daemon(narc.shout)
app.register_daemon(narc.async_shout)
app.run()


Expand Down

0 comments on commit 1385eb6

Please sign in to comment.