Skip to content

Commit 0444960

Browse files
committed
Clean up minor issues in ZMQ broker
- Remove AiiDA import from server.py, default to json encoder/decoder; service.py passes YAML encoder/decoder explicitly - Remove unused threading lock from server - Remove unused dataclasses from protocol.py (all construction uses make_* factory functions) - Move handler dispatch dict to instance attribute (avoid per-message allocation) - Replace hasattr broker checks with isinstance(broker, ZmqBroker)
1 parent 3a54329 commit 0444960

5 files changed

Lines changed: 31 additions & 93 deletions

File tree

src/aiida/brokers/zmq/protocol.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747

4848
import json
4949
import uuid
50-
from dataclasses import dataclass, field
5150
from enum import Enum
5251
from typing import Any
5352

@@ -77,71 +76,6 @@ class MessageType(str, Enum):
7776
PING = 'ping'
7877

7978

80-
@dataclass
81-
class Message:
82-
"""Base message structure."""
83-
84-
type: MessageType
85-
id: str = field(default_factory=lambda: uuid.uuid4().hex)
86-
sender: str = ''
87-
88-
89-
@dataclass
90-
class TaskMessage(Message):
91-
"""Task message sent to the broker for processing."""
92-
93-
type: MessageType = field(default=MessageType.TASK, init=False)
94-
body: Any = None
95-
no_reply: bool = False
96-
97-
98-
@dataclass
99-
class TaskResponse(Message):
100-
"""Response to a task message."""
101-
102-
type: MessageType = field(default=MessageType.TASK_RESPONSE, init=False)
103-
task_id: str = ''
104-
result: Any = None
105-
error: str | None = None
106-
107-
108-
@dataclass
109-
class RpcMessage(Message):
110-
"""RPC message sent to a specific recipient."""
111-
112-
type: MessageType = field(default=MessageType.RPC, init=False)
113-
recipient: str = ''
114-
body: Any = None
115-
116-
117-
@dataclass
118-
class RpcResponse(Message):
119-
"""Response to an RPC message."""
120-
121-
type: MessageType = field(default=MessageType.RPC_RESPONSE, init=False)
122-
rpc_id: str = ''
123-
result: Any = None
124-
error: str | None = None
125-
126-
127-
@dataclass
128-
class BroadcastMessage(Message):
129-
"""Broadcast message sent to all subscribers."""
130-
131-
type: MessageType = field(default=MessageType.BROADCAST, init=False)
132-
body: Any = None
133-
subject: str | None = None
134-
correlation_id: str | None = None
135-
136-
137-
@dataclass
138-
class SubscribeMessage(Message):
139-
"""Subscription request message."""
140-
141-
type: MessageType = MessageType.SUBSCRIBE_TASK
142-
identifier: str | None = None
143-
144-
14579
def encode_message(msg: dict, encoder=json.dumps) -> bytes:
14680
"""Encode a message dictionary to bytes."""
14781
return encoder(msg).encode('utf-8')

src/aiida/brokers/zmq/server.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
1-
"""ZeroMQ Broker Server - standalone message broker.
1+
"""ZeroMQ Broker Server.
22
3-
This module is completely independent of AiiDA and can be used as a standalone
4-
message broker server. It handles:
3+
Can be started as a standalone message broker process. It handles:
54
- Task queue management with persistence
65
- Request/reply routing for RPC
76
- Broadcast distribution
87
"""
98

109
from __future__ import annotations
1110

11+
import json
1212
import logging
13-
import threading
1413
import time
1514
from collections import deque
1615
from pathlib import Path
1716
from typing import Any, Callable
1817

1918
import zmq
2019

21-
from aiida.brokers.utils import YAML_DECODER, YAML_ENCODER
22-
2320
from .protocol import MessageType, decode_message, encode_message
2421
from .queue import PersistentQueue
2522

@@ -54,8 +51,8 @@ def __init__(
5451
:param encoder: Function to encode messages (default: yaml.dump)
5552
:param decoder: Function to decode messages (default: yaml.load)
5653
"""
57-
encoder = encoder if encoder is not None else YAML_ENCODER
58-
decoder = decoder if decoder is not None else YAML_DECODER
54+
encoder = encoder if encoder is not None else json.dumps
55+
decoder = decoder if decoder is not None else json.loads
5956
self._storage_path = Path(storage_path)
6057
self._sockets_path = Path(sockets_path)
6158

@@ -96,7 +93,21 @@ def __init__(
9693

9794
# Server state
9895
self._running = False
99-
self._lock = threading.Lock()
96+
97+
# Message type -> handler mapping (built once, not per message)
98+
self._handlers: dict[str, Callable] = {
99+
MessageType.TASK.value: self._handle_task,
100+
MessageType.TASK_RESPONSE.value: self._handle_task_response,
101+
MessageType.TASK_ACK.value: self._handle_task_ack,
102+
MessageType.TASK_NACK.value: self._handle_task_nack,
103+
MessageType.RPC.value: self._handle_rpc,
104+
MessageType.RPC_RESPONSE.value: self._handle_rpc_response,
105+
MessageType.BROADCAST.value: self._handle_broadcast,
106+
MessageType.SUBSCRIBE_TASK.value: self._handle_subscribe_task,
107+
MessageType.SUBSCRIBE_RPC.value: self._handle_subscribe_rpc,
108+
MessageType.UNSUBSCRIBE_TASK.value: self._handle_unsubscribe_task,
109+
MessageType.UNSUBSCRIBE_RPC.value: self._handle_unsubscribe_rpc,
110+
}
100111

101112
@property
102113
def storage_path(self) -> Path:
@@ -257,22 +268,7 @@ def _handle_router_message(self) -> None:
257268
_LOGGER.warning('Message missing type field')
258269
return
259270

260-
# Route by message type
261-
handlers: dict[str, Any] = {
262-
MessageType.TASK.value: self._handle_task,
263-
MessageType.TASK_RESPONSE.value: self._handle_task_response,
264-
MessageType.TASK_ACK.value: self._handle_task_ack,
265-
MessageType.TASK_NACK.value: self._handle_task_nack,
266-
MessageType.RPC.value: self._handle_rpc,
267-
MessageType.RPC_RESPONSE.value: self._handle_rpc_response,
268-
MessageType.BROADCAST.value: self._handle_broadcast,
269-
MessageType.SUBSCRIBE_TASK.value: self._handle_subscribe_task,
270-
MessageType.SUBSCRIBE_RPC.value: self._handle_subscribe_rpc,
271-
MessageType.UNSUBSCRIBE_TASK.value: self._handle_unsubscribe_task,
272-
MessageType.UNSUBSCRIBE_RPC.value: self._handle_unsubscribe_rpc,
273-
}
274-
275-
handler = handlers.get(msg_type)
271+
handler = self._handlers.get(msg_type)
276272
if handler:
277273
handler(identity, msg)
278274
else:

src/aiida/brokers/zmq/service.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@ def start(self) -> None:
6666
self._sockets_path = Path(tempfile.mkdtemp(prefix='aiida_zmq_'))
6767
self._sockets_file.write_text(str(self._sockets_path))
6868

69+
from aiida.brokers.utils import YAML_DECODER, YAML_ENCODER
70+
6971
self._server = ZmqBrokerServer(
7072
storage_path=self._storage_path,
7173
sockets_path=self._sockets_path,
74+
encoder=YAML_ENCODER,
75+
decoder=YAML_DECODER,
7276
)
7377

7478
self._pid_file.write_text(str(os.getpid()))

src/aiida/cmdline/commands/cmd_daemon.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def status(ctx, all_profiles, timeout):
151151
# Build broker status line for managed brokers (e.g., ZMQ)
152152
broker_line = ''
153153
broker = get_manager().get_broker()
154-
if broker is not None and hasattr(broker, 'get_service_status'):
154+
from aiida.brokers.zmq.broker import ZmqBroker
155+
156+
if isinstance(broker, ZmqBroker):
155157
if broker.is_running():
156158
status_info = broker.get_service_status()
157159
if status_info:

src/aiida/cmdline/commands/cmd_status.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def verdi_status(print_traceback, no_rmq):
156156
else:
157157
daemon_msg = f'Daemon is running with PID {daemon_status["pid"]}'
158158
# Append broker info for managed brokers (e.g., ZMQ)
159-
if hasattr(broker, 'get_service_status'):
159+
from aiida.brokers.zmq.broker import ZmqBroker
160+
161+
if isinstance(broker, ZmqBroker):
160162
if broker.is_running():
161163
status_info = broker.get_service_status()
162164
if status_info:

0 commit comments

Comments
 (0)