Skip to content

Simplest possible consumer #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ DEBUG=False

BROKER_URL=amqp://guest:guest@localhost:5672/
BROKER_EXCHANGE=guest-exchange
BROKER_ROUTING_KEYS=["test-create-event-key", "test-notification-event-key"]
BROKER_QUEUE=websockets-notifications-queue
BROKER_ROUTING_KEYS_CONSUME_FROM=["test-event-boobs"]

WEBSOCKETS_HOST=localhost
WEBSOCKETS_PORT=6789
Expand Down
3 changes: 2 additions & 1 deletion src/app/conf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
class Settings(BaseSettings):
BROKER_URL: AmqpDsn
BROKER_EXCHANGE: str
BROKER_ROUTING_KEYS: list[str]
BROKER_QUEUE: str
BROKER_ROUTING_KEYS_CONSUME_FROM: list[str]
WEBSOCKETS_HOST: str
WEBSOCKETS_PORT: int
WEBSOCKETS_PATH: str
Expand Down
13 changes: 9 additions & 4 deletions src/app/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@


@pytest.fixture
def ws():
return MockedWebSocketServerProtocol()
def create_ws():
return lambda: MockedWebSocketServerProtocol()


@pytest.fixture
def ya_ws():
return MockedWebSocketServerProtocol()
def ws(create_ws):
return create_ws()


@pytest.fixture
def ya_ws(create_ws):
return create_ws()
3 changes: 2 additions & 1 deletion src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def settings(mocker):
return_value=Settings(
BROKER_URL="amqp://guest:guest@localhost/",
BROKER_EXCHANGE="test-exchange",
BROKER_ROUTING_KEYS=["test-routing-key", "ya-test-routing-key"],
BROKER_QUEUE="test-queue",
BROKER_ROUTING_KEYS_CONSUME_FROM=["test-routing-key", "ya-test-routing-key"],
WEBSOCKETS_HOST="localhost",
WEBSOCKETS_PORT=50000,
WEBSOCKETS_PATH="/v2/test-subscription-websocket",
Expand Down
73 changes: 73 additions & 0 deletions src/consumer/consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
from dataclasses import dataclass
import logging
from typing import Protocol

import aio_pika

from app import conf
from consumer.dto import ConsumedMessage
from consumer.dto import OutgoingMessage
from storage.subscription_storage import SubscriptionStorage
from pydantic import ValidationError
import websockets


logger = logging.getLogger(__name__)


class ConsumerProtocol(Protocol):
async def consume(self) -> None:
pass


@dataclass
class Consumer:
storage: SubscriptionStorage

def __post_init__(self) -> None:
settings = conf.get_app_settings()

self.broker_url: str = str(settings.BROKER_URL)
self.exchange: str = settings.BROKER_EXCHANGE
self.queue: str = settings.BROKER_QUEUE
self.routing_keys_consume_from: list[str] = settings.BROKER_ROUTING_KEYS_CONSUME_FROM

async def consume(self, stop_signal: asyncio.Future) -> None:
connection = await aio_pika.connect_robust(self.broker_url)

async with connection:
channel = await connection.channel()

exchange = await channel.declare_exchange(self.exchange, type=aio_pika.ExchangeType.DIRECT)
queue = await channel.declare_queue(name=self.queue, exclusive=True)

for routing_key in self.routing_keys_consume_from:
await queue.bind(exchange=exchange, routing_key=routing_key)

await queue.consume(self.on_message)

await stop_signal

async def on_message(self, raw_message: aio_pika.abc.AbstractIncomingMessage) -> None:
async with raw_message.process():
processed_messages = self.parse_message(raw_message)

if processed_messages:
self.broadcast_subscribers(self.storage, processed_messages)

@staticmethod
def parse_message(raw_message: aio_pika.abc.AbstractIncomingMessage) -> ConsumedMessage | None:
try:
return ConsumedMessage.model_validate_json(raw_message.body)
except ValidationError as exc:
logger.error("Consumed message not in expected format. Errors: %s", exc.errors())
return None

@staticmethod
def broadcast_subscribers(storage: SubscriptionStorage, message: ConsumedMessage) -> None:
websockets_to_broadcast = storage.get_event_subscribers_websockets(message.event)

if websockets_to_broadcast:
outgoing_message = OutgoingMessage.model_construct(payload=message)
websockets.broadcast(websockets=websockets_to_broadcast, message=outgoing_message.model_dump_json())
15 changes: 15 additions & 0 deletions src/consumer/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from typing import Literal
from app.types import Event


class ConsumedMessage(BaseModel):
model_config = ConfigDict(extra="allow")

event: Event


class OutgoingMessage(BaseModel):
message_type: Literal["EventNotification"] = "EventNotification"
payload: ConsumedMessage
37 changes: 37 additions & 0 deletions src/consumer/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from consumer.consumer import Consumer
from dataclasses import dataclass
from contextlib import asynccontextmanager
from typing import AsyncGenerator


@pytest.fixture(autouse=True)
def _adjust_settings(settings):
settings.BROKER_URL = "amqp://guest:guest@localhost/"
settings.BROKER_EXCHANGE = "test-exchange"
settings.BROKER_QUEUE = "test-queue"
settings.BROKER_ROUTING_KEYS_CONSUME_FROM = [
"event-routing-key",
"ya-event-routing-key",
]


@pytest.fixture
def consumer(storage) -> Consumer:
return Consumer(storage=storage)


@dataclass
class MockedIncomingMessage:
"""The simplest Incoming message class that represent incoming amqp message.

The safer choice is to use 'aio_pika.abc.AbstractIncomingMessage,' but the test setup will be significantly more challenging.
"""

body: bytes

@asynccontextmanager
async def process(self) -> AsyncGenerator:
"""Do nothing, just for compatibility with aio_pika.abc.AbstractIncomingMessage."""
yield None
30 changes: 30 additions & 0 deletions src/consumer/tests/tests_consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio
import pytest


@pytest.fixture
def run_consume_task(consumer):
def run(stop_signal: asyncio.Future):
return asyncio.create_task(consumer.consume(stop_signal))

return run


def test_consumer_attributes(consumer):
assert consumer.broker_url == "amqp://guest:guest@localhost/"
assert consumer.exchange == "test-exchange"
assert consumer.queue == "test-queue"
assert consumer.routing_keys_consume_from == [
"event-routing-key",
"ya-event-routing-key",
]


async def test_consumer_correctly_stopped_on_stop_signal(run_consume_task):
stop_signal = asyncio.get_running_loop().create_future()
consumer_task = run_consume_task(stop_signal)

stop_signal.set_result(None)

await asyncio.sleep(0.1) # get enough time to stop the task
assert consumer_task.done() is True
68 changes: 68 additions & 0 deletions src/consumer/tests/tests_consumer_on_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from contextlib import nullcontext as does_not_raise
import json
import pytest

from consumer.tests.conftest import MockedIncomingMessage


@pytest.fixture(autouse=True)
def mock_broadcast(mocker):
return mocker.patch("websockets.broadcast")


def python_to_bytes(data: dict) -> bytes:
return json.dumps(data).encode()


@pytest.fixture
def broker_message_data(event):
return {
"event": event,
"size": 3,
"quantity": 2,
}


@pytest.fixture
def ya_ws_subscribed(create_ws, ya_valid_token, ws_register_and_subscribe, event):
return ws_register_and_subscribe(create_ws(), ya_valid_token, event)


@pytest.fixture
def ya_user_ws_subscribed(create_ws, ya_user_valid_token, ws_register_and_subscribe, event):
return ws_register_and_subscribe(create_ws(), ya_user_valid_token, event)


@pytest.fixture
def consumed_message(broker_message_data):
return MockedIncomingMessage(body=python_to_bytes(broker_message_data))


@pytest.fixture
def on_message(consumer, consumed_message):
return lambda message=consumed_message: consumer.on_message(message)


async def test_broadcast_message_to_subscriber_websockets(on_message, ws_subscribed, mock_broadcast, mocker):
await on_message()

mock_broadcast.assert_called_once_with(websockets=[ws_subscribed], message=mocker.ANY)


async def test_broadcast_message_to_all_subscribers_websockets(on_message, mock_broadcast, ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed, mocker):
await on_message()

mock_broadcast.assert_called_once()
broadcasted_websockets = mock_broadcast.call_args.kwargs["websockets"]
assert len(broadcasted_websockets) == 3
assert set(broadcasted_websockets) == {ws_subscribed, ya_ws_subscribed, ya_user_ws_subscribed}


async def test_log_and_do_nothing_if_message_not_expected_format(on_message, ws_subscribed, mock_broadcast, consumed_message, caplog):
consumed_message.body = b"invalid-json"

with does_not_raise():
await on_message(consumed_message)

assert "Consumed message not in expected format" in caplog.text
mock_broadcast.assert_not_called()
3 changes: 3 additions & 0 deletions src/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import signal

import websockets
import logging

from app import conf
from handlers import WebSocketsHandler
from handlers import WebSocketsAccessGuardian
from storage.subscription_storage import SubscriptionStorage

logging.basicConfig(level=logging.INFO)


def create_stop_signal() -> asyncio.Future[None]:
loop = asyncio.get_running_loop()
Expand Down
10 changes: 10 additions & 0 deletions src/storage/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,13 @@ def subscribe(ws, event):
def ws_subscribed(ws_registered, subscribe_ws, event):
subscribe_ws(ws_registered, event)
return ws_registered


@pytest.fixture
def ws_register_and_subscribe(register_ws, subscribe_ws):
def register_and_subscribe(ws, token, event):
register_ws(ws, token)
subscribe_ws(ws, event)
return ws

return register_and_subscribe
19 changes: 6 additions & 13 deletions src/storage/subscription_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,15 @@ def get_websocket_user_id(self, websocket: WebSocketServerProtocol) -> UserId |
websocket_meta = self.registered_websockets.get(websocket)
return websocket_meta.user_id if websocket_meta else None

def get_event_subscribers_user_ids(self, event: Event) -> set[UserId]:
return self.subscriptions.get(event) or set()
def get_event_subscribers_websockets(self, event: Event) -> list[WebSocketServerProtocol]:
subscribers_user_ids = self.subscriptions.get(event) or set()

def is_event_has_subscribers(self, event: Event) -> bool:
return event in self.subscriptions
user_websockets = []

def get_users_websockets(self, user_ids: set[UserId]) -> list[WebSocketServerProtocol]:
users_websockets = []
for user_id in subscribers_user_ids:
user_websockets.extend(self.user_connections[user_id].websockets)

for user_id in user_ids:
user_connection_meta = self.user_connections.get(user_id)

if user_connection_meta:
users_websockets.extend(user_connection_meta.websockets)

return users_websockets
return user_websockets

def get_expired_websockets(self) -> list[WebSocketServerProtocol]:
now_timestamp = time.time()
Expand Down