Skip to content

Commit

Permalink
Rename interface of Coordinator
Browse files Browse the repository at this point in the history
subscriber -> receiver to distinguish RMQ communicator.
  • Loading branch information
unkcpz committed Feb 21, 2025
1 parent f28a5ab commit 031906d
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 1,123 deletions.
39 changes: 19 additions & 20 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol
from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol
from re import Pattern

if TYPE_CHECKING:
# identifiers for subscribers
ID_TYPE = Hashable
Subscriber = Callable[..., Any]
# RPC subscriber params: communicator, msg
RpcSubscriber = Callable[[Any], Any]
# Task subscriber params: communicator, task
TaskSubscriber = Callable[[Any], Any]
# Broadcast subscribers params: communicator, body, sender, subject, correlation id
BroadcastSubscriber = Callable[[Any, Any, Any, ID_TYPE], Any]
Receiver = Callable[..., Any]


class Coordinator(Protocol):
# XXX: naming - 'add_message_handler'
def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | None' = None) -> Any: ...
def hook_rpc_receiver(
self,
receiver: 'Receiver',
identifier: 'ID_TYPE | None' = None,
) -> Any: ...

# XXX: naming - 'add_broadcast_handler'
def add_broadcast_subscriber(
def hook_broadcast_receiver(
self,
subscriber: 'BroadcastSubscriber',
receiver: 'Receiver',
subject_filters: list[Hashable | Pattern[str]] | None = None,
sender_filters: list[Hashable | Pattern[str]] | None = None,
identifier: 'ID_TYPE | None' = None,
) -> Any: ...

# XXX: naming - absorbed into 'add_message_handler'
def add_task_subscriber(self, subscriber: 'TaskSubscriber', identifier: 'ID_TYPE | None' = None) -> 'ID_TYPE': ...
def hook_task_receiver(
self,
receiver: 'Receiver',
identifier: 'ID_TYPE | None' = None,
) -> 'ID_TYPE': ...

def remove_rpc_subscriber(self, identifier: 'ID_TYPE | None') -> None: ...
def unhook_rpc_receiver(self, identifier: 'ID_TYPE | None') -> None: ...

def remove_broadcast_subscriber(self, identifier: 'ID_TYPE | None') -> None: ...
def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: ...

def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: ...
def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None: ...

def rpc_send(self, recipient_id: Hashable, msg: Any) -> Any: ...
def rpc_send(self, recipient_id: Hashable, msg: Any,) -> Any: ...

def broadcast_send(
self,
Expand Down
10 changes: 5 additions & 5 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,18 @@ def init(self) -> None:

if self._coordinator is not None:
try:
identifier = self._coordinator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid))
self.add_cleanup(functools.partial(self._coordinator.remove_rpc_subscriber, identifier))
identifier = self._coordinator.hook_rpc_receiver(self.message_receive, identifier=str(self.pid))
self.add_cleanup(functools.partial(self._coordinator.unhook_rpc_receiver, identifier))
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid)
# XXX: handle duplicate subscribing here: see aiida-core test_duplicate_subscriber_identifier.

# XXX: handle duplicate subscribing here: see aiida-core test_duplicate_subscriber_identifier.
try:
# filter out state change broadcasts
subscriber = BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*'))
identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid))
identifier = self._coordinator.hook_broadcast_receiver(subscriber, identifier=str(self.pid))

self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier))
self.add_cleanup(functools.partial(self._coordinator.unhook_broadcast_receiver, identifier))
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid)

Expand Down
60 changes: 33 additions & 27 deletions tests/rmq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from re import Pattern
from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, final
from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, TypeVar, final
import kiwipy
import concurrent.futures

from plumpy.exceptions import CoordinatorConnectionError

if TYPE_CHECKING:
ID_TYPE = Hashable
BroadcastSubscriber = Callable[[Any, Any, Any, ID_TYPE], Any]
Receiver = Callable[..., Any]

U = TypeVar('U', bound=kiwipy.Communicator)

Expand All @@ -24,54 +24,61 @@ def communicator(self) -> U:
"""The inner communicator."""
return self._comm

# XXX: naming - `add_receiver_rpc`
def add_rpc_subscriber(self, subscriber, identifier=None):
def hook_rpc_receiver(
self,
receiver: 'Receiver',
identifier: 'ID_TYPE | None' = None,
) -> Any:
def _subscriber(_, *args, **kwargs):
return subscriber(*args, **kwargs)
return receiver(*args, **kwargs)

return self._comm.add_rpc_subscriber(_subscriber, identifier)

# XXX: naming - `add_receiver_broadcast`
def add_broadcast_subscriber(
def hook_broadcast_receiver(
self,
subscriber: 'BroadcastSubscriber',
receiver: 'Receiver',
subject_filters: list[Hashable | Pattern[str]] | None = None,
sender_filters: list[Hashable | Pattern[str]] | None = None,
identifier: 'ID_TYPE | None' = None,
):
) -> Any:
def _subscriber(_, *args, **kwargs):
return subscriber(*args, **kwargs)
return receiver(*args, **kwargs)

return self._comm.add_broadcast_subscriber(_subscriber, identifier)

# XXX: naming - `add_reciver_task` (can be combined with two above maybe??)
def add_task_subscriber(self, subscriber, identifier=None):
def hook_task_receiver(
self,
receiver: 'Receiver',
identifier: 'ID_TYPE | None' = None,
) -> 'ID_TYPE':
async def _subscriber(_comm, *args, **kwargs):
return await subscriber(*args, **kwargs)
return await receiver(*args, **kwargs)

return self._comm.add_task_subscriber(_subscriber, identifier)

def remove_rpc_subscriber(self, identifier):
def unhook_rpc_receiver(self, identifier: 'ID_TYPE | None') -> None:
return self._comm.remove_rpc_subscriber(identifier)

def remove_broadcast_subscriber(self, identifier):
def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None:
return self._comm.remove_broadcast_subscriber(identifier)

def remove_task_subscriber(self, identifier):
def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None:
return self._comm.remove_task_subscriber(identifier)

# XXX: naming - `send_to`
def rpc_send(self, recipient_id, msg):
def rpc_send(
self,
recipient_id: Hashable,
msg: Any,
) -> Any:
return self._comm.rpc_send(recipient_id, msg)

# XXX: naming - `broadcast`
def broadcast_send(
self,
body,
sender=None,
subject=None,
correlation_id=None,
):
body: Any | None,
sender: 'ID_TYPE | None' = None,
subject: str | None = None,
correlation_id: 'ID_TYPE | None' = None,
) -> Any:
from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError

try:
Expand All @@ -81,9 +88,8 @@ def broadcast_send(
else:
return rsp

# XXX: naming - `assign_task` (this may able to be combined with send_to)
def task_send(self, task, no_reply=False):
def task_send(self, task: Any, no_reply: bool = False) -> Any:
return self._comm.task_send(task, no_reply)

def close(self):
def close(self) -> None:
self._comm.close()
43 changes: 22 additions & 21 deletions tests/rmq/test_communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def broadcast_send(self, body, sender=None, subject=None, correlation_id=None):


@pytest.fixture
def subscriber():
def receiver_fn():
"""Return an instance of mocked `Subscriber`."""

class Subscriber:
Expand All @@ -43,40 +43,41 @@ def __call__(self):
return Subscriber()


def test_add_rpc_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.add_rpc_subscriber` method."""
assert _coordinator.add_rpc_subscriber(subscriber) is not None
def test_hook_rpc_receiver(_coordinator, receiver_fn):
"""Test the `LoopCommunicator.add_rpc_receiver` method."""
assert _coordinator.hook_rpc_receiver(receiver_fn) is not None

identifier = 'identifier'
assert _coordinator.add_rpc_subscriber(subscriber, identifier) == identifier
assert _coordinator.hook_rpc_receiver(receiver_fn, identifier) == identifier


def test_remove_rpc_subscriber(_coordinator, subscriber):
def test_unhook_rpc_receiver(_coordinator, receiver_fn):
"""Test the `LoopCommunicator.remove_rpc_subscriber` method."""
identifier = _coordinator.add_rpc_subscriber(subscriber)
_coordinator.remove_rpc_subscriber(identifier)
identifier = _coordinator.hook_rpc_receiver(receiver_fn)
_coordinator.unhook_rpc_receiver(identifier)


def test_add_broadcast_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.add_broadcast_subscriber` method."""
assert _coordinator.add_broadcast_subscriber(subscriber) is not None
def test_hook_broadcast_receiver(_coordinator, receiver_fn):
"""Test the coordinator hook_broadcast_receiver which calls
`LoopCommunicator.add_broadcast_subscriber` method."""
assert _coordinator.hook_broadcast_receiver(receiver_fn) is not None

identifier = 'identifier'
assert _coordinator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier
assert _coordinator.hook_broadcast_receiver(receiver_fn, identifier=identifier) == identifier


def test_remove_broadcast_subscriber(_coordinator, subscriber):
def test_unhook_broadcast_receiver(_coordinator, receiver_fn):
"""Test the `LoopCommunicator.remove_broadcast_subscriber` method."""
identifier = _coordinator.add_broadcast_subscriber(subscriber)
_coordinator.remove_broadcast_subscriber(identifier)
identifier = _coordinator.hook_broadcast_receiver(receiver_fn)
_coordinator.unhook_broadcast_receiver(identifier)


def test_add_task_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.add_task_subscriber` method."""
assert _coordinator.add_task_subscriber(subscriber) is not None
def test_hook_task_receiver(_coordinator, receiver_fn):
"""Test the hook_task_receiver calls `LoopCommunicator.add_task_subscriber` method."""
assert _coordinator.hook_task_receiver(receiver_fn) is not None


def test_remove_task_subscriber(_coordinator, subscriber):
def test_unhook_task_receiver(_coordinator, receiver_fn):
"""Test the `LoopCommunicator.remove_task_subscriber` method."""
identifier = _coordinator.add_task_subscriber(subscriber)
_coordinator.remove_task_subscriber(identifier)
identifier = _coordinator.hook_task_receiver(receiver_fn)
_coordinator.unhook_task_receiver(identifier)
22 changes: 11 additions & 11 deletions tests/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_broadcast(body, sender, subject, correlation_id):
{'body': body, 'sender': sender, 'subject': subject, 'correlation_id': correlation_id}
)

_coordinator.add_broadcast_subscriber(get_broadcast)
_coordinator.hook_broadcast_receiver(get_broadcast)
_coordinator.broadcast_send(**BROADCAST)

result = await broadcast_future
Expand All @@ -96,8 +96,8 @@ def ignore_broadcast(body, sender, subject, correlation_id):
def get_broadcast(body, sender, subject, correlation_id):
broadcast_future.set_result(True)

_coordinator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other'))
_coordinator.add_broadcast_subscriber(get_broadcast)
_coordinator.hook_broadcast_receiver(BroadcastFilter(ignore_broadcast, subject='other'))
_coordinator.hook_broadcast_receiver(get_broadcast)
_coordinator.broadcast_send(**{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420})

result = await broadcast_future
Expand All @@ -114,7 +114,7 @@ def get_rpc(msg):
assert loop is asyncio.get_event_loop()
rpc_future.set_result(msg)

_coordinator.add_rpc_subscriber(get_rpc, 'rpc')
_coordinator.hook_rpc_receiver(get_rpc, 'rpc')
_coordinator.rpc_send('rpc', MSG)

result = await rpc_future
Expand All @@ -131,7 +131,7 @@ def get_task(msg):
assert loop is asyncio.get_event_loop()
task_future.set_result(msg)

_coordinator.add_task_subscriber(get_task)
_coordinator.hook_task_receiver(get_task)
_coordinator.task_send(TASK)

# TODO: Error in the event loop log although the test pass
Expand All @@ -146,7 +146,7 @@ async def test_launch(self, _coordinator, async_controller, persister):
# Let the process run to the end
loop = asyncio.get_event_loop()
launcher = plumpy.ProcessLauncher(loop, persister=persister)
_coordinator.add_task_subscriber(launcher.call)
_coordinator.hook_task_receiver(launcher.call)
result = await async_controller.launch_process(utils.DummyProcess)
# Check that we got a result
assert result == utils.DummyProcess.EXPECTED_OUTPUTS
Expand All @@ -156,7 +156,7 @@ async def test_launch_nowait(self, _coordinator, async_controller, persister):
"""Testing launching but don't wait, just get the pid"""
loop = asyncio.get_event_loop()
launcher = plumpy.ProcessLauncher(loop, persister=persister)
_coordinator.add_task_subscriber(launcher.call)
_coordinator.hook_task_receiver(launcher.call)
pid = await async_controller.launch_process(utils.DummyProcess, nowait=True)
assert isinstance(pid, uuid.UUID)

Expand All @@ -165,7 +165,7 @@ async def test_execute_action(self, _coordinator, async_controller, persister):
"""Test the process execute action"""
loop = asyncio.get_event_loop()
launcher = plumpy.ProcessLauncher(loop, persister=persister)
_coordinator.add_task_subscriber(launcher.call)
_coordinator.hook_task_receiver(launcher.call)
result = await async_controller.execute_process(utils.DummyProcessWithOutput)
assert utils.DummyProcessWithOutput.EXPECTED_OUTPUTS == result

Expand All @@ -174,7 +174,7 @@ async def test_execute_action_nowait(self, _coordinator, async_controller, persi
"""Test the process execute action"""
loop = asyncio.get_event_loop()
launcher = plumpy.ProcessLauncher(loop, persister=persister)
_coordinator.add_task_subscriber(launcher.call)
_coordinator.hook_task_receiver(launcher.call)
pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True)
assert isinstance(pid, uuid.UUID)

Expand All @@ -183,7 +183,7 @@ async def test_launch_many(self, _coordinator, async_controller, persister):
"""Test launching multiple processes"""
loop = asyncio.get_event_loop()
launcher = plumpy.ProcessLauncher(loop, persister=persister)
_coordinator.add_task_subscriber(launcher.call)
_coordinator.hook_task_receiver(launcher.call)
num_to_launch = 10

launch_futures = []
Expand All @@ -200,7 +200,7 @@ async def test_continue(self, _coordinator, async_controller, persister):
"""Test continuing a saved process"""
loop = asyncio.get_event_loop()
launcher = plumpy.ProcessLauncher(loop, persister=persister)
_coordinator.add_task_subscriber(launcher.call)
_coordinator.hook_task_receiver(launcher.call)
process = utils.DummyProcessWithOutput()
persister.save_checkpoint(process)
pid = process.pid
Expand Down
2 changes: 1 addition & 1 deletion tests/rmq/test_process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_broadcast(self, _coordinator):
def on_broadcast_receive(**msg):
messages.append(msg)

_coordinator.add_broadcast_subscriber(on_broadcast_receive)
_coordinator.hook_broadcast_receiver(on_broadcast_receive)

proc = utils.DummyProcess(coordinator=_coordinator)
proc.execute()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ def test_broadcast(self):
def on_broadcast_receive(body, sender, subject, correlation_id):
messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id})

coordinator.add_broadcast_subscriber(on_broadcast_receive)
coordinator.hook_broadcast_receiver(on_broadcast_receive)
proc = utils.DummyProcess(coordinator=coordinator)
proc.execute()

Expand Down
Loading

0 comments on commit 031906d

Please sign in to comment.