From da48ce996a19d0fe5a8ed80ac66213a031117079 Mon Sep 17 00:00:00 2001 From: Charles Powell Date: Mon, 1 Jan 2024 10:03:00 -0700 Subject: [PATCH] Upgrade to aiomqtt (fixes #23) --- senselink/mqtt/mqtt_controller.py | 114 +++++++++++------------------- setup.py | 8 +-- 2 files changed, 44 insertions(+), 78 deletions(-) diff --git a/senselink/mqtt/mqtt_controller.py b/senselink/mqtt/mqtt_controller.py index eb98186..e9028d3 100644 --- a/senselink/mqtt/mqtt_controller.py +++ b/senselink/mqtt/mqtt_controller.py @@ -2,9 +2,8 @@ import logging import asyncio +from aiomqtt import Client, MqttError from typing import Dict -from asyncio_mqtt import Client, MqttError -from contextlib import AsyncExitStack from .mqtt_listener import MQTTListener @@ -12,17 +11,6 @@ MQTT_LOGGER.setLevel(logging.WARNING) -async def cancel_tasks(tasks): - for task in tasks: - if task.done(): - continue - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - class MQTTController: client = None topics: Dict[str, MQTTListener] = None @@ -34,7 +22,7 @@ def __init__(self, host, port=1883, username=None, password=None): self.password = password self.data_sources = [] - self.topics = {} + self.listeners = {} async def connect(self): # Create task @@ -43,9 +31,12 @@ async def connect(self): async def client_handler(self): logging.info(f"Starting MQTT client to URL: {self.host}") reconnect_interval = 5 # [seconds] + + client = Client(self.host, self.port, username=self.username, password=self.password) while True: try: - await self.listen() + async with client: + await self.listen(client) except MqttError as error: logging.error(f'Disconnected from MQTT broker with error: {error}') logging.debug(f'MQTT client disconnected/ended, reconnecting in {reconnect_interval}...') @@ -56,65 +47,40 @@ async def client_handler(self): logging.error(f'Stopping MQTT client with error: {error}') return False - async def listen(self): - async with AsyncExitStack() as stack: - # Track tasks - tasks = set() - stack.push_async_callback(cancel_tasks, tasks) - - # Connect to the MQTT broker - client = Client(self.host, self.port, username=self.username, password=self.password) - await stack.enter_async_context(client) - - logging.info(f'MQTT client connected') - # Add tasks for each data source handler - for ds in self.data_sources: - # Get handlers from data source - ds_listeners = ds.listeners() - # Iterate through data source listeners and convert to - # 'prime' listeners for each topic - for listener in ds_listeners: - topic = listener.topic - funcs = listener.handlers - if topic in self.topics: - # Add these handlers to existing top level topic handler - logging.debug(f'Adding handlers for existing prime Listener: {topic}') - ext_topic = self.topics[topic] - ext_topic.handlers.extend(funcs) - else: - # Add this instance as a new top level handler - logging.debug(f'Creating new prime Listener for topic: {topic}') - self.topics[topic] = MQTTListener(topic, funcs) - - # Add handlers for each topic as a filtered topic - for topic, listener in self.topics.items(): - manager = client.filtered_messages(topic) - messages = await stack.enter_async_context(manager) - task = asyncio.create_task(self.parse_messages(messages)) - tasks.add(task) - - # Subscribe to all topics - # Assume QoS 0 for now - all_topics = [(t, 0) for t in self.topics.keys()] - logging.info(f'Subscribing to MQTT {len(all_topics)} topic(s)') - logging.debug(f'Topics: {all_topics}') - try: - await client.subscribe(all_topics) - except ValueError as err: - logging.error(f'MQTT Subscribe error: {err}') - - # Gather all tasks - await asyncio.gather(*tasks) - logging.info(f'Listening for MQTT updates') - - async def parse_messages(self, messages): - async for message in messages: - topic = message.topic - # Get handlers and iterate through - listener = self.topics[topic] - for func in listener.handlers: - # Decode to UTF-8 - await func(message.payload.decode()) + async def listen(self, client): + logging.info(f'MQTT client connected') + # Add tasks for each data source handler + for ds in self.data_sources: + # Get handlers from data source + ds_listeners = ds.listeners() + # Iterate through data source listeners and convert to + # 'prime' listeners for each topic + for listener in ds_listeners: + topic = listener.topic + funcs = listener.handlers + if topic in self.listeners: + # Add these handlers to existing top level topic handler + logging.debug(f'Adding handlers for existing prime Listener: {topic}') + ext_topic = self.listeners[topic] + ext_topic.handlers.extend(funcs) + else: + # Add this instance as a new top level handler + logging.debug(f'Creating new prime Listener for topic: {topic}') + self.listeners[topic] = MQTTListener(topic, funcs) + + async with client.messages() as messages: + # Subscribe to specified topics + for topic, handlers in self.listeners.items(): + await client.subscribe(topic) + # Handle messages that come in + async for message in messages: + topic = message.topic.value + handlers = self.listeners[topic].handlers + logging.debug(f'Got message for topic: {topic}') + for func in handlers: + # Decode to UTF-8 + payload = message.payload.decode() + await func(payload) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 9e7c5b1..78fac06 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='SenseLink', - version='2.2.0', + version='2.2.1', description='A tool to create virtual smart plugs and inform a Sense Home Energy Monitor about usage in your home', long_description=long_description, long_description_content_type="text/markdown", @@ -14,10 +14,10 @@ author_email='cbpowell@gmail.com', license='MIT', packages=find_packages(), - install_requires=['asyncio-mqtt>=0.12.1', - 'dpath>=2.0.6', + install_requires=['aiomqtt~=1.2', + 'dpath~=2.1', 'paho-mqtt>=1.6.1', - 'PyYAML>=6.0', + 'PyYAML~=6.0', 'websockets>=10.2' ],