Skip to content

Commit

Permalink
Upgrade to aiomqtt (fixes #23)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbpowell committed Jan 1, 2024
1 parent 2f06f89 commit da48ce9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 78 deletions.
114 changes: 40 additions & 74 deletions senselink/mqtt/mqtt_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,15 @@

import logging
import asyncio
from aiomqtt import Client, MqttError
from typing import Dict
from asyncio_mqtt import Client, MqttError
from contextlib import AsyncExitStack

from .mqtt_listener import MQTTListener

MQTT_LOGGER = logging.getLogger('mqtt')
MQTT_LOGGER.setLevel(logging.WARNING)


async def cancel_tasks(tasks):
for task in tasks:
if task.done():
continue
task.cancel()
try:
await task
except asyncio.CancelledError:
pass


class MQTTController:
client = None
topics: Dict[str, MQTTListener] = None
Expand All @@ -34,7 +22,7 @@ def __init__(self, host, port=1883, username=None, password=None):
self.password = password

self.data_sources = []
self.topics = {}
self.listeners = {}

async def connect(self):
# Create task
Expand All @@ -43,9 +31,12 @@ async def connect(self):
async def client_handler(self):
logging.info(f"Starting MQTT client to URL: {self.host}")
reconnect_interval = 5 # [seconds]

client = Client(self.host, self.port, username=self.username, password=self.password)
while True:
try:
await self.listen()
async with client:
await self.listen(client)
except MqttError as error:
logging.error(f'Disconnected from MQTT broker with error: {error}')
logging.debug(f'MQTT client disconnected/ended, reconnecting in {reconnect_interval}...')
Expand All @@ -56,65 +47,40 @@ async def client_handler(self):
logging.error(f'Stopping MQTT client with error: {error}')
return False

async def listen(self):
async with AsyncExitStack() as stack:
# Track tasks
tasks = set()
stack.push_async_callback(cancel_tasks, tasks)

# Connect to the MQTT broker
client = Client(self.host, self.port, username=self.username, password=self.password)
await stack.enter_async_context(client)

logging.info(f'MQTT client connected')
# Add tasks for each data source handler
for ds in self.data_sources:
# Get handlers from data source
ds_listeners = ds.listeners()
# Iterate through data source listeners and convert to
# 'prime' listeners for each topic
for listener in ds_listeners:
topic = listener.topic
funcs = listener.handlers
if topic in self.topics:
# Add these handlers to existing top level topic handler
logging.debug(f'Adding handlers for existing prime Listener: {topic}')
ext_topic = self.topics[topic]
ext_topic.handlers.extend(funcs)
else:
# Add this instance as a new top level handler
logging.debug(f'Creating new prime Listener for topic: {topic}')
self.topics[topic] = MQTTListener(topic, funcs)

# Add handlers for each topic as a filtered topic
for topic, listener in self.topics.items():
manager = client.filtered_messages(topic)
messages = await stack.enter_async_context(manager)
task = asyncio.create_task(self.parse_messages(messages))
tasks.add(task)

# Subscribe to all topics
# Assume QoS 0 for now
all_topics = [(t, 0) for t in self.topics.keys()]
logging.info(f'Subscribing to MQTT {len(all_topics)} topic(s)')
logging.debug(f'Topics: {all_topics}')
try:
await client.subscribe(all_topics)
except ValueError as err:
logging.error(f'MQTT Subscribe error: {err}')

# Gather all tasks
await asyncio.gather(*tasks)
logging.info(f'Listening for MQTT updates')

async def parse_messages(self, messages):
async for message in messages:
topic = message.topic
# Get handlers and iterate through
listener = self.topics[topic]
for func in listener.handlers:
# Decode to UTF-8
await func(message.payload.decode())
async def listen(self, client):
logging.info(f'MQTT client connected')
# Add tasks for each data source handler
for ds in self.data_sources:
# Get handlers from data source
ds_listeners = ds.listeners()
# Iterate through data source listeners and convert to
# 'prime' listeners for each topic
for listener in ds_listeners:
topic = listener.topic
funcs = listener.handlers
if topic in self.listeners:
# Add these handlers to existing top level topic handler
logging.debug(f'Adding handlers for existing prime Listener: {topic}')
ext_topic = self.listeners[topic]
ext_topic.handlers.extend(funcs)
else:
# Add this instance as a new top level handler
logging.debug(f'Creating new prime Listener for topic: {topic}')
self.listeners[topic] = MQTTListener(topic, funcs)

async with client.messages() as messages:
# Subscribe to specified topics
for topic, handlers in self.listeners.items():
await client.subscribe(topic)
# Handle messages that come in
async for message in messages:
topic = message.topic.value
handlers = self.listeners[topic].handlers
logging.debug(f'Got message for topic: {topic}')
for func in handlers:
# Decode to UTF-8
payload = message.payload.decode()
await func(payload)


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='SenseLink',
version='2.2.0',
version='2.2.1',
description='A tool to create virtual smart plugs and inform a Sense Home Energy Monitor about usage in your home',
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -14,10 +14,10 @@
author_email='[email protected]',
license='MIT',
packages=find_packages(),
install_requires=['asyncio-mqtt>=0.12.1',
'dpath>=2.0.6',
install_requires=['aiomqtt~=1.2',
'dpath~=2.1',
'paho-mqtt>=1.6.1',
'PyYAML>=6.0',
'PyYAML~=6.0',
'websockets>=10.2'
],

Expand Down

0 comments on commit da48ce9

Please sign in to comment.