|
@@ -6,6 +6,8 @@ import json
|
|
|
from typing import Dict, Callable, List, Optional, Any
|
|
from typing import Dict, Callable, List, Optional, Any
|
|
|
import paho.mqtt.client as mqtt
|
|
import paho.mqtt.client as mqtt
|
|
|
import logging
|
|
import logging
|
|
|
|
|
+import asyncio
|
|
|
|
|
+from functools import partial
|
|
|
|
|
|
|
|
from .base import BaseMQTTHandler
|
|
from .base import BaseMQTTHandler
|
|
|
from modules.core.state import state
|
|
from modules.core.state import state
|
|
@@ -66,6 +68,9 @@ class MQTTHandler(BaseMQTTHandler):
|
|
|
self.state = state
|
|
self.state = state
|
|
|
self.state.mqtt_handler = self # Set reference to self in state, needed so that state setters can update the state
|
|
self.state.mqtt_handler = self # Set reference to self in state, needed so that state setters can update the state
|
|
|
|
|
|
|
|
|
|
+ # Store the main event loop during initialization
|
|
|
|
|
+ self.main_loop = asyncio.get_event_loop()
|
|
|
|
|
+
|
|
|
def setup_ha_discovery(self):
|
|
def setup_ha_discovery(self):
|
|
|
"""Publish Home Assistant MQTT discovery configurations."""
|
|
"""Publish Home Assistant MQTT discovery configurations."""
|
|
|
if not self.is_enabled:
|
|
if not self.is_enabled:
|
|
@@ -227,13 +232,13 @@ class MQTTHandler(BaseMQTTHandler):
|
|
|
else:
|
|
else:
|
|
|
self.client.publish(f"{self.pattern_select_topic}/state", "None", retain=True)
|
|
self.client.publish(f"{self.pattern_select_topic}/state", "None", retain=True)
|
|
|
|
|
|
|
|
- def _publish_playlist_state(self, playlist=None):
|
|
|
|
|
|
|
+ def _publish_playlist_state(self, playlist_name=None):
|
|
|
"""Helper to publish playlist state."""
|
|
"""Helper to publish playlist state."""
|
|
|
- if playlist is None:
|
|
|
|
|
- playlist = self.state.current_playlist
|
|
|
|
|
|
|
+ if playlist_name is None:
|
|
|
|
|
+ playlist_name = self.state.current_playlist_name
|
|
|
|
|
|
|
|
- if playlist:
|
|
|
|
|
- self.client.publish(f"{self.playlist_select_topic}/state", playlist, retain=True)
|
|
|
|
|
|
|
+ if playlist_name:
|
|
|
|
|
+ self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
|
|
|
else:
|
|
else:
|
|
|
self.client.publish(f"{self.playlist_select_topic}/state", "None", retain=True)
|
|
self.client.publish(f"{self.playlist_select_topic}/state", "None", retain=True)
|
|
|
|
|
|
|
@@ -244,7 +249,7 @@ class MQTTHandler(BaseMQTTHandler):
|
|
|
serial_status = f"connected to {serial_port}" if serial_connected else "disconnected"
|
|
serial_status = f"connected to {serial_port}" if serial_connected else "disconnected"
|
|
|
self.client.publish(self.serial_state_topic, serial_status, retain=True)
|
|
self.client.publish(self.serial_state_topic, serial_status, retain=True)
|
|
|
|
|
|
|
|
- def update_state(self, current_file=None, is_running=None, playlist=None):
|
|
|
|
|
|
|
+ def update_state(self, current_file=None, is_running=None, playlist=None, playlist_name=None):
|
|
|
"""Update state in Home Assistant. Only publishes the attributes that are explicitly passed."""
|
|
"""Update state in Home Assistant. Only publishes the attributes that are explicitly passed."""
|
|
|
if not self.is_enabled:
|
|
if not self.is_enabled:
|
|
|
return
|
|
return
|
|
@@ -259,8 +264,8 @@ class MQTTHandler(BaseMQTTHandler):
|
|
|
self._publish_running_state(running_state)
|
|
self._publish_running_state(running_state)
|
|
|
|
|
|
|
|
# Update playlist state if playlist info is provided
|
|
# Update playlist state if playlist info is provided
|
|
|
- if playlist is not None:
|
|
|
|
|
- self._publish_playlist_state(playlist)
|
|
|
|
|
|
|
+ if playlist_name is not None:
|
|
|
|
|
+ self._publish_playlist_state(playlist_name)
|
|
|
|
|
|
|
|
def on_connect(self, client, userdata, flags, rc):
|
|
def on_connect(self, client, userdata, flags, rc):
|
|
|
"""Callback when connected to MQTT broker."""
|
|
"""Callback when connected to MQTT broker."""
|
|
@@ -299,13 +304,21 @@ class MQTTHandler(BaseMQTTHandler):
|
|
|
# Handle pattern selection
|
|
# Handle pattern selection
|
|
|
pattern_name = msg.payload.decode()
|
|
pattern_name = msg.payload.decode()
|
|
|
if pattern_name in self.patterns:
|
|
if pattern_name in self.patterns:
|
|
|
- self.callback_registry['run_pattern'](file_path=f"{THETA_RHO_DIR}/{pattern_name}")
|
|
|
|
|
|
|
+ # Schedule the coroutine to run in the main event loop
|
|
|
|
|
+ asyncio.run_coroutine_threadsafe(
|
|
|
|
|
+ self.callback_registry['run_pattern'](file_path=f"{THETA_RHO_DIR}/{pattern_name}"),
|
|
|
|
|
+ self.main_loop
|
|
|
|
|
+ )
|
|
|
self.client.publish(f"{self.pattern_select_topic}/state", pattern_name, retain=True)
|
|
self.client.publish(f"{self.pattern_select_topic}/state", pattern_name, retain=True)
|
|
|
elif msg.topic == self.playlist_select_topic:
|
|
elif msg.topic == self.playlist_select_topic:
|
|
|
# Handle playlist selection
|
|
# Handle playlist selection
|
|
|
playlist_name = msg.payload.decode()
|
|
playlist_name = msg.payload.decode()
|
|
|
if playlist_name in self.playlists:
|
|
if playlist_name in self.playlists:
|
|
|
- self.callback_registry['run_playlist'](playlist_name=playlist_name)
|
|
|
|
|
|
|
+ # Schedule the coroutine to run in the main event loop
|
|
|
|
|
+ asyncio.run_coroutine_threadsafe(
|
|
|
|
|
+ self.callback_registry['run_playlist'](playlist_name=playlist_name),
|
|
|
|
|
+ self.main_loop
|
|
|
|
|
+ )
|
|
|
self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
|
|
self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
|
|
|
elif msg.topic == self.speed_topic:
|
|
elif msg.topic == self.speed_topic:
|
|
|
speed = int(msg.payload.decode())
|
|
speed = int(msg.payload.decode())
|