Skip to content

Commit

Permalink
Add example script.
Browse files Browse the repository at this point in the history
  • Loading branch information
YooSunYoung committed Dec 13, 2024
1 parent 0f187b2 commit 05ac23d
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 32 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ test = [
"rich",
]

[project.scripts]
appstract-helloworld = "appstract.script:run_helloworld"

[project.urls]
"Bug Tracker" = "https://github.com/scipp/appstract/issues"
"Documentation" = "https://scipp.github.io/appstract"
Expand Down
81 changes: 49 additions & 32 deletions src/appstract/async.py → src/appstract/asyncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(self, logger: AppLogger, message_router: MessageRouter) -> None:
self.logger = logger
self.message_router = message_router
self.daemons: list[DaemonMessageGeneratorProtocol] = [self.message_router.run]
self.register_handling_method(self.Stop, self.stop_tasks)
self.register_handler(self.Stop, self.stop_tasks)
self._break = False
super().__init__()

Expand All @@ -189,7 +189,7 @@ def stop_tasks(self, message: MessageProtocol | None = None) -> None:
)
self._break = True

def register_handling_method(
def register_handler(
self, event_tp: type[MessageT], handler: Callable[[MessageT], Any]
) -> None:
"""Register handlers to the application message router."""
Expand Down Expand Up @@ -248,6 +248,21 @@ async def run_daemon(daemon: DaemonMessageGeneratorProtocol):

return {daemon: run_daemon(daemon) for daemon in self.daemons}

@contextmanager
def _handle_keyboard_interrupt(self):
_interrupted_count = 0
try:
yield
except KeyboardInterrupt as e:
if _interrupted_count < 1:
self.info("Received a keyboard interrupt. Exiting...")
self.info("Press Ctrl+C one more time to kill immediately.")
self.message_router.message_pipe.put_nowait(
Application.Stop(content=None)
)
else:
raise e

def run(self):
"""
Register all handling methods and run all daemons.
Expand All @@ -267,23 +282,24 @@ def run(self):

from appstract.schedulers import temporary_event_loop

self.info('Start running %s...', self.__class__.__qualname__)
if self.tasks:
raise RuntimeError(
"Application is already running. "
"Cancel all tasks and clear them before running it again."
)

with temporary_event_loop() as loop:
self.loop = loop
daemon_coroutines = self._create_daemon_coroutines()
daemon_tasks = {
daemon: loop.create_task(coro)
for daemon, coro in daemon_coroutines.items()
}
self.tasks.update(daemon_tasks)
if not loop.is_running():
loop.run_until_complete(asyncio.gather(*self.tasks.values()))
with self._handle_keyboard_interrupt():
self.info('Start running %s...', self.__class__.__qualname__)
if self.tasks:
raise RuntimeError(
"Application is already running. "
"Cancel all tasks and clear them before running it again."
)

with temporary_event_loop() as loop:
self.loop = loop
daemon_coroutines = self._create_daemon_coroutines()
daemon_tasks = {
daemon: loop.create_task(coro)
for daemon, coro in daemon_coroutines.items()
}
self.tasks.update(daemon_tasks)
if not loop.is_running():
loop.run_until_complete(asyncio.gather(*self.tasks.values()))

def run_after_run(self):
"""
Expand All @@ -294,16 +310,17 @@ def run_after_run(self):
"""
import asyncio

self.info('Start running %s...', self.__class__.__qualname__)
if self.tasks:
raise RuntimeError(
"Application is already running. "
"Cancel all tasks and clear them before running it again."
)
self.loop = asyncio.get_event_loop()
daemon_coroutines = self._create_daemon_coroutines()
daemon_tasks = {
daemon: self.loop.create_task(coro)
for daemon, coro in daemon_coroutines.items()
}
self.tasks.update(daemon_tasks)
with self._handle_keyboard_interrupt():
self.info('Start running %s...', self.__class__.__qualname__)
if self.tasks:
raise RuntimeError(
"Application is already running. "
"Cancel all tasks and clear them before running it again."
)
self.loop = asyncio.get_event_loop()
daemon_coroutines = self._create_daemon_coroutines()
daemon_tasks = {
daemon: self.loop.create_task(coro)
for daemon, coro in daemon_coroutines.items()
}
self.tasks.update(daemon_tasks)
111 changes: 111 additions & 0 deletions src/appstract/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
import argparse
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from importlib.metadata import entry_points
from typing import Protocol, TypeVar

from .asyncs import Application, MessageProtocol, MessageRouter
from .constructors import (
Factory,
ProviderGroup,
SingletonProvider,
multiple_constant_providers,
)
from .logging import AppLogger
from .logging.providers import log_providers
from .mixins import LogMixin

T = TypeVar("T", bound="ArgumentInstantiable")


def list_entry_points() -> list[str]:
return [ep.name for ep in entry_points(group='beamlime.workflow_plugin')]


def build_arg_parser(*sub_group_classes: type) -> argparse.ArgumentParser:
"""Builds the minimum argument parser for the highest-level entry point."""
parser = argparse.ArgumentParser(description="BEAMLIME configuration.")
parser.add_argument(
"--log-level",
help="Set logging level. Default is INFO.",
type=str,
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
)
for sub_group_class in sub_group_classes:
if callable(add_arg := getattr(sub_group_class, "add_argument_group", None)):
add_arg(parser)

return parser


class ArgumentInstantiable(Protocol):
@classmethod
def add_argument_group(cls, parser: argparse.ArgumentParser) -> None: ...

@classmethod
def from_args(cls: type[T], logger: AppLogger, args: argparse.Namespace) -> T: ...


def instantiate_from_args(
logger: AppLogger, args: argparse.Namespace, tp: type[T]
) -> T:
return tp.from_args(logger=logger, args=args)


@dataclass
class HelloWorldMessage:
content: str


class Echo(LogMixin):
logger: AppLogger

async def echo(self, msg: MessageProtocol) -> None:
await asyncio.sleep(0.5)
self.error(msg.content)


class Narc(LogMixin):
logger: AppLogger

async def shout(self) -> AsyncGenerator[MessageProtocol, None]:
self.info("Going to shout hello world 3 times...")
messages = ("Hello World", "Heelllloo World!", "Heeelllllloooo World!")
for msg in messages:
self.info(msg)
yield HelloWorldMessage(msg)
await asyncio.sleep(1)

yield Application.Stop(content=None)


def run_helloworld():
arg_name_space: argparse.Namespace = build_arg_parser().parse_args()
parameters = {argparse.Namespace: arg_name_space}

factory = Factory(
log_providers,
ProviderGroup(
SingletonProvider(Application),
SingletonProvider(MessageRouter),
Echo,
Narc,
),
)

with multiple_constant_providers(factory, parameters):
factory[AppLogger].setLevel(arg_name_space.log_level.upper())
app = factory[Application]
echo = factory[Echo]
narc = factory[Narc]

# Handlers
app.register_handler(HelloWorldMessage, echo.echo)

# Daemons
app.register_daemon(narc.shout)
app.run()

0 comments on commit 05ac23d

Please sign in to comment.