Skip to content

Commit db6fe66

Browse files
committed
Simplest possible consumer
1 parent 60997d4 commit db6fe66

File tree

12 files changed

+257
-20
lines changed

12 files changed

+257
-20
lines changed

env.example

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ DEBUG=False
22

33
BROKER_URL=amqp://guest:guest@localhost:5672/
44
BROKER_EXCHANGE=guest-exchange
5-
BROKER_ROUTING_KEYS=["test-create-event-key", "test-notification-event-key"]
5+
BROKER_QUEUE=websockets-notifications-queue
6+
BROKER_ROUTING_KEYS_CONSUME_FROM=["test-event-boobs"]
67

78
WEBSOCKETS_HOST=localhost
89
WEBSOCKETS_PORT=6789

src/app/conf/settings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
class Settings(BaseSettings):
1010
BROKER_URL: AmqpDsn
1111
BROKER_EXCHANGE: str
12-
BROKER_ROUTING_KEYS: list[str]
12+
BROKER_QUEUE: str
13+
BROKER_ROUTING_KEYS_CONSUME_FROM: list[str]
1314
WEBSOCKETS_HOST: str
1415
WEBSOCKETS_PORT: int
1516
WEBSOCKETS_PATH: str

src/app/fixtures.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@
44

55

66
@pytest.fixture
7-
def ws():
8-
return MockedWebSocketServerProtocol()
7+
def create_ws():
8+
return lambda: MockedWebSocketServerProtocol()
99

1010

1111
@pytest.fixture
12-
def ya_ws():
13-
return MockedWebSocketServerProtocol()
12+
def ws(create_ws):
13+
return create_ws()
14+
15+
16+
@pytest.fixture
17+
def ya_ws(create_ws):
18+
return create_ws()

src/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def settings(mocker):
1515
return_value=Settings(
1616
BROKER_URL="amqp://guest:guest@localhost/",
1717
BROKER_EXCHANGE="test-exchange",
18-
BROKER_ROUTING_KEYS=["test-routing-key", "ya-test-routing-key"],
18+
BROKER_QUEUE="test-queue",
19+
BROKER_ROUTING_KEYS_CONSUME_FROM=["test-routing-key", "ya-test-routing-key"],
1920
WEBSOCKETS_HOST="localhost",
2021
WEBSOCKETS_PORT=50000,
2122
WEBSOCKETS_PATH="/v2/test-subscription-websocket",

src/consumer/consumer.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import asyncio
2+
from dataclasses import dataclass
3+
import logging
4+
from typing import Protocol
5+
6+
import aio_pika
7+
8+
from app import conf
9+
from consumer.dto import ConsumedMessage
10+
from consumer.dto import OutgoingMessage
11+
from storage.subscription_storage import SubscriptionStorage
12+
from pydantic import ValidationError
13+
import websockets
14+
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class ConsumerProtocol(Protocol):
20+
async def consume(self) -> None:
21+
pass
22+
23+
24+
@dataclass
25+
class Consumer:
26+
storage: SubscriptionStorage
27+
28+
def __post_init__(self) -> None:
29+
settings = conf.get_app_settings()
30+
31+
self.broker_url: str = str(settings.BROKER_URL)
32+
self.exchange: str = settings.BROKER_EXCHANGE
33+
self.queue: str = settings.BROKER_QUEUE
34+
self.routing_keys_consume_from: list[str] = settings.BROKER_ROUTING_KEYS_CONSUME_FROM
35+
36+
async def consume(self, stop_signal: asyncio.Future) -> None:
37+
connection = await aio_pika.connect_robust(self.broker_url)
38+
39+
async with connection:
40+
channel = await connection.channel()
41+
42+
exchange = await channel.declare_exchange(self.exchange, type=aio_pika.ExchangeType.DIRECT)
43+
queue = await channel.declare_queue(name=self.queue, exclusive=True)
44+
45+
for routing_key in self.routing_keys_consume_from:
46+
await queue.bind(exchange=exchange, routing_key=routing_key)
47+
48+
await queue.consume(self.on_message)
49+
50+
await stop_signal
51+
52+
async def on_message(self, raw_message: aio_pika.abc.AbstractIncomingMessage) -> None:
53+
async with raw_message.process():
54+
processed_messages = self.parse_message(raw_message)
55+
56+
if processed_messages:
57+
self.broadcast_subscribers(self.storage, processed_messages)
58+
59+
@staticmethod
60+
def parse_message(raw_message: aio_pika.abc.AbstractIncomingMessage) -> ConsumedMessage | None:
61+
try:
62+
return ConsumedMessage.model_validate_json(raw_message.body)
63+
except ValidationError as exc:
64+
logger.error("Consumed message not in expected format. Errors: %s", exc.errors())
65+
return None
66+
67+
@staticmethod
68+
def broadcast_subscribers(storage: SubscriptionStorage, message: ConsumedMessage) -> None:
69+
websockets_to_broadcast = storage.get_event_subscribers_websockets(message.event)
70+
71+
if websockets_to_broadcast:
72+
outgoing_message = OutgoingMessage.model_construct(payload=message)
73+
websockets.broadcast(websockets=websockets_to_broadcast, message=outgoing_message.model_dump_json())

src/consumer/dto.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from pydantic import BaseModel
2+
from pydantic import ConfigDict
3+
from typing import Literal
4+
from app.types import Event
5+
6+
7+
class ConsumedMessage(BaseModel):
8+
model_config = ConfigDict(extra="allow")
9+
10+
event: Event
11+
12+
13+
class OutgoingMessage(BaseModel):
14+
message_type: Literal["EventNotification"] = "EventNotification"
15+
payload: ConsumedMessage

src/consumer/tests/conftest.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
from consumer.consumer import Consumer
4+
from dataclasses import dataclass
5+
from contextlib import asynccontextmanager
6+
from typing import AsyncGenerator
7+
8+
9+
@pytest.fixture(autouse=True)
10+
def _adjust_settings(settings):
11+
settings.BROKER_URL = "amqp://guest:guest@localhost/"
12+
settings.BROKER_EXCHANGE = "test-exchange"
13+
settings.BROKER_QUEUE = "test-queue"
14+
settings.BROKER_ROUTING_KEYS_CONSUME_FROM = [
15+
"event-routing-key",
16+
"ya-event-routing-key",
17+
]
18+
19+
20+
@pytest.fixture
21+
def consumer(storage) -> Consumer:
22+
return Consumer(storage=storage)
23+
24+
25+
@dataclass
26+
class MockedIncomingMessage:
27+
"""The simplest Incoming message class that represent incoming amqp message.
28+
29+
The safer choice is to use 'aio_pika.abc.AbstractIncomingMessage,' but the test setup will be significantly more challenging.
30+
"""
31+
32+
body: bytes
33+
34+
@asynccontextmanager
35+
async def process(self) -> AsyncGenerator:
36+
"""Do nothing, just for compatibility with aio_pika.abc.AbstractIncomingMessage."""
37+
yield None

src/consumer/tests/tests_consumer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import asyncio
2+
import pytest
3+
4+
5+
@pytest.fixture
6+
def run_consume_task(consumer):
7+
def run(stop_signal: asyncio.Future):
8+
return asyncio.create_task(consumer.consume(stop_signal))
9+
10+
return run
11+
12+
13+
def test_consumer_attributes(consumer):
14+
assert consumer.broker_url == "amqp://guest:guest@localhost/"
15+
assert consumer.exchange == "test-exchange"
16+
assert consumer.queue == "test-queue"
17+
assert consumer.routing_keys_consume_from == [
18+
"event-routing-key",
19+
"ya-event-routing-key",
20+
]
21+
22+
23+
async def test_consumer_correctly_stopped_on_stop_signal(run_consume_task):
24+
stop_signal = asyncio.get_running_loop().create_future()
25+
consumer_task = run_consume_task(stop_signal)
26+
27+
stop_signal.set_result(None)
28+
29+
await asyncio.sleep(0.1) # get enough time to stop the task
30+
assert consumer_task.done() is True
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from contextlib import nullcontext as does_not_raise
2+
import json
3+
import pytest
4+
5+
from consumer.tests.conftest import MockedIncomingMessage
6+
7+
8+
@pytest.fixture(autouse=True)
9+
def mock_broadcast(mocker):
10+
return mocker.patch("websockets.broadcast")
11+
12+
13+
def python_to_bytes(data: dict) -> bytes:
14+
return json.dumps(data).encode()
15+
16+
17+
@pytest.fixture
18+
def broker_message_data(event):
19+
return {
20+
"event": event,
21+
"size": 3,
22+
"quantity": 2,
23+
}
24+
25+
26+
@pytest.fixture
27+
def ya_ws_subscribed(create_ws, ya_valid_token, ws_register_and_subscribe, event):
28+
return ws_register_and_subscribe(create_ws(), ya_valid_token, event)
29+
30+
31+
@pytest.fixture
32+
def ya_user_ws_subscribed(create_ws, ya_user_valid_token, ws_register_and_subscribe, event):
33+
return ws_register_and_subscribe(create_ws(), ya_user_valid_token, event)
34+
35+
36+
@pytest.fixture
37+
def consumed_message(broker_message_data):
38+
return MockedIncomingMessage(body=python_to_bytes(broker_message_data))
39+
40+
41+
@pytest.fixture
42+
def on_message(consumer, consumed_message):
43+
return lambda message=consumed_message: consumer.on_message(message)
44+
45+
46+
async def test_broadcast_message_to_subscriber_websockets(on_message, ws_subscribed, mock_broadcast, mocker):
47+
await on_message()
48+
49+
mock_broadcast.assert_called_once_with(websockets=[ws_subscribed], message=mocker.ANY)
50+
51+
52+
async def test_broadcast_message_to_all_subscribers_websockets(on_message, mock_broadcast, ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed, mocker):
53+
await on_message()
54+
55+
mock_broadcast.assert_called_once()
56+
broadcasted_websockets = mock_broadcast.call_args.kwargs["websockets"]
57+
assert len(broadcasted_websockets) == 3
58+
assert set(broadcasted_websockets) == {ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed}
59+
60+
61+
async def test_log_and_do_nothing_if_message_not_expected_format(on_message, ws_subscribed, mock_broadcast, consumed_message, caplog):
62+
consumed_message.body = b"invalid-json"
63+
64+
with does_not_raise():
65+
await on_message(consumed_message)
66+
67+
assert "Consumed message not in expected format" in caplog.text
68+
mock_broadcast.assert_not_called()

src/entrypoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
import signal
33

44
import websockets
5+
import logging
56

67
from app import conf
78
from handlers import WebSocketsHandler
89
from handlers import WebSocketsAccessGuardian
910
from storage.subscription_storage import SubscriptionStorage
1011

12+
logging.basicConfig(level=logging.INFO)
13+
1114

1215
def create_stop_signal() -> asyncio.Future[None]:
1316
loop = asyncio.get_running_loop()

src/storage/fixtures.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,13 @@ def subscribe(ws, event):
6666
def ws_subscribed(ws_registered, subscribe_ws, event):
6767
subscribe_ws(ws_registered, event)
6868
return ws_registered
69+
70+
71+
@pytest.fixture
72+
def ws_register_and_subscribe(register_ws, subscribe_ws):
73+
def register_and_subscribe(ws, token, event):
74+
register_ws(ws, token)
75+
subscribe_ws(ws, event)
76+
return ws
77+
78+
return register_and_subscribe

src/storage/subscription_storage.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,15 @@ def get_websocket_user_id(self, websocket: WebSocketServerProtocol) -> UserId |
2626
websocket_meta = self.registered_websockets.get(websocket)
2727
return websocket_meta.user_id if websocket_meta else None
2828

29-
def get_event_subscribers_user_ids(self, event: Event) -> set[UserId]:
30-
return self.subscriptions.get(event) or set()
29+
def get_event_subscribers_websockets(self, event: Event) -> list[WebSocketServerProtocol]:
30+
subscribers_user_ids = self.subscriptions.get(event) or set()
3131

32-
def is_event_has_subscribers(self, event: Event) -> bool:
33-
return event in self.subscriptions
32+
user_websockets = []
3433

35-
def get_users_websockets(self, user_ids: set[UserId]) -> list[WebSocketServerProtocol]:
36-
users_websockets = []
34+
for user_id in subscribers_user_ids:
35+
user_websockets.extend(self.user_connections[user_id].websockets)
3736

38-
for user_id in user_ids:
39-
user_connection_meta = self.user_connections.get(user_id)
40-
41-
if user_connection_meta:
42-
users_websockets.extend(user_connection_meta.websockets)
43-
44-
return users_websockets
37+
return user_websockets
4538

4639
def get_expired_websockets(self) -> list[WebSocketServerProtocol]:
4740
now_timestamp = time.time()

0 commit comments

Comments
 (0)