Procházet zdrojové kódy

fix ha integration

Tuan Nguyen před 11 měsíci
rodič
revize
32815118fe

+ 1 - 0
modules/core/playlist_manager.py

@@ -106,6 +106,7 @@ async def run_playlist(playlist_name, pause_time=0, clear_pattern=None, run_mode
     try:
     try:
         logger.info(f"Starting playlist '{playlist_name}' with mode={run_mode}, shuffle={shuffle}")
         logger.info(f"Starting playlist '{playlist_name}' with mode={run_mode}, shuffle={shuffle}")
         state.current_playlist = file_paths
         state.current_playlist = file_paths
+        state.current_playlist_name = playlist_name
         asyncio.create_task(
         asyncio.create_task(
             pattern_manager.run_theta_rho_files(
             pattern_manager.run_theta_rho_files(
                 file_paths,
                 file_paths,

+ 13 - 0
modules/core/state.py

@@ -10,6 +10,7 @@ class AppState:
         self._pause_requested = False
         self._pause_requested = False
         self._speed = 150
         self._speed = 150
         self._current_playlist = None
         self._current_playlist = None
+        self._current_playlist_name = None  # New variable for playlist name
         
         
         # Regular state variables
         # Regular state variables
         self.stop_requested = False
         self.stop_requested = False
@@ -87,6 +88,16 @@ class AppState:
         if self.mqtt_handler:
         if self.mqtt_handler:
             self.mqtt_handler.update_state(playlist=value)
             self.mqtt_handler.update_state(playlist=value)
 
 
+    @property
+    def current_playlist_name(self):
+        return self._current_playlist_name
+
+    @current_playlist_name.setter
+    def current_playlist_name(self, value):
+        self._current_playlist_name = value
+        if self.mqtt_handler:
+            self.mqtt_handler.update_state(playlist_name=value)
+
     def to_dict(self):
     def to_dict(self):
         """Return a dictionary representation of the state."""
         """Return a dictionary representation of the state."""
         return {
         return {
@@ -105,6 +116,7 @@ class AppState:
             "gear_ratio": self.gear_ratio,
             "gear_ratio": self.gear_ratio,
             "homing": self.homing,
             "homing": self.homing,
             "current_playlist": self._current_playlist,
             "current_playlist": self._current_playlist,
+            "current_playlist_name": self._current_playlist_name,
             "current_playlist_index": self.current_playlist_index,
             "current_playlist_index": self.current_playlist_index,
             "playlist_mode": self.playlist_mode,
             "playlist_mode": self.playlist_mode,
             "port": self.port,
             "port": self.port,
@@ -128,6 +140,7 @@ class AppState:
         self.gear_ratio = data.get('gear_ratio', 10)
         self.gear_ratio = data.get('gear_ratio', 10)
         self.homing = data.get('homing', 0)
         self.homing = data.get('homing', 0)
         self._current_playlist = data.get("current_playlist")
         self._current_playlist = data.get("current_playlist")
+        self._current_playlist_name = data.get("current_playlist_name")
         self.current_playlist_index = data.get("current_playlist_index")
         self.current_playlist_index = data.get("current_playlist_index")
         self.playlist_mode = data.get("playlist_mode")
         self.playlist_mode = data.get("playlist_mode")
         self.port = data.get("port", None)
         self.port = data.get("port", None)

+ 10 - 2
modules/mqtt/factory.py

@@ -1,6 +1,8 @@
 """Factory for creating MQTT handlers."""
 """Factory for creating MQTT handlers."""
 import os
 import os
 from typing import Dict, Callable
 from typing import Dict, Callable
+from dotenv import load_dotenv
+from pathlib import Path
 from .base import BaseMQTTHandler
 from .base import BaseMQTTHandler
 from .handler import MQTTHandler
 from .handler import MQTTHandler
 from .mock import MockMQTTHandler
 from .mock import MockMQTTHandler
@@ -10,6 +12,11 @@ import logging
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
+# Load environment variables
+BASE_DIR = Path(__file__).resolve().parent.parent.parent  # Go up to project root
+env_path = os.path.join(BASE_DIR, '.env')
+load_dotenv(env_path)
+
 def create_mqtt_handler() -> BaseMQTTHandler:
 def create_mqtt_handler() -> BaseMQTTHandler:
     """Create and return an appropriate MQTT handler based on configuration.
     """Create and return an appropriate MQTT handler based on configuration.
     
     
@@ -17,8 +24,9 @@ def create_mqtt_handler() -> BaseMQTTHandler:
         BaseMQTTHandler: Either a real MQTTHandler if MQTT_BROKER is configured,
         BaseMQTTHandler: Either a real MQTTHandler if MQTT_BROKER is configured,
                         or a MockMQTTHandler if not.
                         or a MockMQTTHandler if not.
     """
     """
-    if os.getenv('MQTT_BROKER'):
-        logger.info("Got MQTT configuration, instantiating MQTTHandler")
+    mqtt_broker = os.getenv('MQTT_BROKER')
+    if mqtt_broker:
+        logger.info(f"Got MQTT configuration for broker: {mqtt_broker}, instantiating MQTTHandler")
         return MQTTHandler(create_mqtt_callbacks())
         return MQTTHandler(create_mqtt_callbacks())
     
     
     logger.info("MQTT Not going to be used, instantiating MockMQTTHandler")
     logger.info("MQTT Not going to be used, instantiating MockMQTTHandler")

+ 23 - 10
modules/mqtt/handler.py

@@ -6,6 +6,8 @@ import json
 from typing import Dict, Callable, List, Optional, Any
 from typing import Dict, Callable, List, Optional, Any
 import paho.mqtt.client as mqtt
 import paho.mqtt.client as mqtt
 import logging
 import logging
+import asyncio
+from functools import partial
 
 
 from .base import BaseMQTTHandler
 from .base import BaseMQTTHandler
 from modules.core.state import state
 from modules.core.state import state
@@ -66,6 +68,9 @@ class MQTTHandler(BaseMQTTHandler):
         self.state = state
         self.state = state
         self.state.mqtt_handler = self  # Set reference to self in state, needed so that state setters can update the state
         self.state.mqtt_handler = self  # Set reference to self in state, needed so that state setters can update the state
 
 
+        # Store the main event loop during initialization
+        self.main_loop = asyncio.get_event_loop()
+
     def setup_ha_discovery(self):
     def setup_ha_discovery(self):
         """Publish Home Assistant MQTT discovery configurations."""
         """Publish Home Assistant MQTT discovery configurations."""
         if not self.is_enabled:
         if not self.is_enabled:
@@ -227,13 +232,13 @@ class MQTTHandler(BaseMQTTHandler):
         else:
         else:
             self.client.publish(f"{self.pattern_select_topic}/state", "None", retain=True)
             self.client.publish(f"{self.pattern_select_topic}/state", "None", retain=True)
             
             
-    def _publish_playlist_state(self, playlist=None):
+    def _publish_playlist_state(self, playlist_name=None):
         """Helper to publish playlist state."""
         """Helper to publish playlist state."""
-        if playlist is None:
-            playlist = self.state.current_playlist
+        if playlist_name is None:
+            playlist_name = self.state.current_playlist_name
             
             
-        if playlist:
-            self.client.publish(f"{self.playlist_select_topic}/state", playlist, retain=True)
+        if playlist_name:
+            self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
         else:
         else:
             self.client.publish(f"{self.playlist_select_topic}/state", "None", retain=True)
             self.client.publish(f"{self.playlist_select_topic}/state", "None", retain=True)
             
             
@@ -244,7 +249,7 @@ class MQTTHandler(BaseMQTTHandler):
         serial_status = f"connected to {serial_port}" if serial_connected else "disconnected"
         serial_status = f"connected to {serial_port}" if serial_connected else "disconnected"
         self.client.publish(self.serial_state_topic, serial_status, retain=True)
         self.client.publish(self.serial_state_topic, serial_status, retain=True)
 
 
-    def update_state(self, current_file=None, is_running=None, playlist=None):
+    def update_state(self, current_file=None, is_running=None, playlist=None, playlist_name=None):
         """Update state in Home Assistant. Only publishes the attributes that are explicitly passed."""
         """Update state in Home Assistant. Only publishes the attributes that are explicitly passed."""
         if not self.is_enabled:
         if not self.is_enabled:
             return
             return
@@ -259,8 +264,8 @@ class MQTTHandler(BaseMQTTHandler):
             self._publish_running_state(running_state)
             self._publish_running_state(running_state)
         
         
         # Update playlist state if playlist info is provided
         # Update playlist state if playlist info is provided
-        if playlist is not None:
-            self._publish_playlist_state(playlist)
+        if playlist_name is not None:
+            self._publish_playlist_state(playlist_name)
 
 
     def on_connect(self, client, userdata, flags, rc):
     def on_connect(self, client, userdata, flags, rc):
         """Callback when connected to MQTT broker."""
         """Callback when connected to MQTT broker."""
@@ -299,13 +304,21 @@ class MQTTHandler(BaseMQTTHandler):
                 # Handle pattern selection
                 # Handle pattern selection
                 pattern_name = msg.payload.decode()
                 pattern_name = msg.payload.decode()
                 if pattern_name in self.patterns:
                 if pattern_name in self.patterns:
-                    self.callback_registry['run_pattern'](file_path=f"{THETA_RHO_DIR}/{pattern_name}")
+                    # Schedule the coroutine to run in the main event loop
+                    asyncio.run_coroutine_threadsafe(
+                        self.callback_registry['run_pattern'](file_path=f"{THETA_RHO_DIR}/{pattern_name}"),
+                        self.main_loop
+                    )
                     self.client.publish(f"{self.pattern_select_topic}/state", pattern_name, retain=True)
                     self.client.publish(f"{self.pattern_select_topic}/state", pattern_name, retain=True)
             elif msg.topic == self.playlist_select_topic:
             elif msg.topic == self.playlist_select_topic:
                 # Handle playlist selection
                 # Handle playlist selection
                 playlist_name = msg.payload.decode()
                 playlist_name = msg.payload.decode()
                 if playlist_name in self.playlists:
                 if playlist_name in self.playlists:
-                    self.callback_registry['run_playlist'](playlist_name=playlist_name)
+                    # Schedule the coroutine to run in the main event loop
+                    asyncio.run_coroutine_threadsafe(
+                        self.callback_registry['run_playlist'](playlist_name=playlist_name),
+                        self.main_loop
+                    )
                     self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
                     self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
             elif msg.topic == self.speed_topic:
             elif msg.topic == self.speed_topic:
                 speed = int(msg.payload.decode())
                 speed = int(msg.payload.decode())