Kaynağa Gözat

WIP HA Integration

Fabio De Simone 11 ay önce
ebeveyn
işleme
466ca09859

+ 17 - 0
.env.example

@@ -0,0 +1,17 @@
+# MQTT Configuration
+# Required for real MQTT implementation (leave empty for mock implementation)
+MQTT_BROKER=
+MQTT_PORT=1883
+MQTT_USERNAME=
+MQTT_PASSWORD=
+MQTT_CLIENT_ID=dune_weaver
+
+# MQTT Topics
+MQTT_STATUS_TOPIC=dune_weaver/status
+MQTT_COMMAND_TOPIC=dune_weaver/command
+MQTT_STATUS_INTERVAL=30  # Status update interval in seconds
+
+# Home Assistant MQTT Discovery
+MQTT_DISCOVERY_PREFIX=homeassistant
+HA_DEVICE_NAME=Sand Table
+HA_DEVICE_ID=dune_weaver

+ 2 - 0
dune_weaver_flask/__init__.py

@@ -0,0 +1,2 @@
+from dotenv import load_dotenv
+load_dotenv()

+ 11 - 8
dune_weaver_flask/app.py

@@ -8,6 +8,7 @@ from dune_weaver_flask.modules.core import pattern_manager
 from dune_weaver_flask.modules.core import playlist_manager
 from .modules.firmware import firmware_manager
 from dune_weaver_flask.modules.core.state import state
+from dune_weaver_flask.modules import mqtt
 
 
 # Configure logging
@@ -249,9 +250,8 @@ def serial_status():
 
 @app.route('/pause_execution', methods=['POST'])
 def pause_execution():
-    logger.info("Pausing pattern execution")
-    pattern_manager.pause_requested = True
-    return jsonify({'success': True, 'message': 'Execution paused'})
+    if pattern_manager.pause_execution():
+        return jsonify({'success': True, 'message': 'Execution paused'})
 
 @app.route('/status', methods=['GET'])
 def get_status():
@@ -259,11 +259,8 @@ def get_status():
 
 @app.route('/resume_execution', methods=['POST'])
 def resume_execution():
-    logger.info("Resuming pattern execution")
-    with pattern_manager.pause_condition:
-        pattern_manager.pause_requested = False
-        pattern_manager.pause_condition.notify_all()
-    return jsonify({'success': True, 'message': 'Execution resumed'})
+    if pattern_manager.resume_execution():
+        return jsonify({'success': True, 'message': 'Execution resumed'})
 
 # Playlist endpoints
 @app.route("/list_all_playlists", methods=["GET"])
@@ -422,6 +419,7 @@ def on_exit():
     """Function to execute on application shutdown."""
     pattern_manager.stop_actions()
     state.save()
+    mqtt.cleanup_mqtt()
 
 def entrypoint():
     logger.info("Starting Dune Weaver application...")
@@ -433,6 +431,11 @@ def entrypoint():
         serial_manager.connect_to_serial()
     except Exception as e:
         logger.warning(f"Failed to auto-connect to serial port: {str(e)}")
+        
+    try:
+        mqtt_handler = mqtt.init_mqtt()
+    except Exception as e:
+        logger.warning(f"Failed to initialize MQTT: {str(e)}")
 
     try:
         logger.info("Starting Flask server on port 8080...")

+ 12 - 0
dune_weaver_flask/modules/core/pattern_manager.py

@@ -144,6 +144,18 @@ def interpolate_path(theta, rho):
     state.machine_x = new_x_abs
     state.machine_y = new_y_abs
     
+def pause_execution():
+    logger.info("Pausing pattern execution")
+    pattern_manager.pause_requested = True
+    return True
+
+def resume_execution():
+    logger.info("Resuming pattern execution")
+    with pattern_manager.pause_condition:
+        pattern_manager.pause_requested = False
+        pattern_manager.pause_condition.notify_all()
+    return True
+    
 def reset_theta():
     logger.info('Resetting Theta')
     state.current_theta = 0

+ 30 - 0
dune_weaver_flask/modules/mqtt/__init__.py

@@ -0,0 +1,30 @@
+"""MQTT module for Dune Weaver Flask application."""
+from .factory import create_mqtt_handler
+import logging
+
+logger = logging.getLogger(__name__)
+# Global MQTT handler instance
+mqtt_handler = None
+
+def init_mqtt():
+    """Initialize the MQTT handler."""
+    global mqtt_handler
+    logger.info("initializing mqtt module")
+    if mqtt_handler is None:
+        mqtt_handler = create_mqtt_handler()
+        mqtt_handler.start()
+    return mqtt_handler
+
+def get_mqtt_handler():
+    """Get the MQTT handler instance."""
+    global mqtt_handler
+    if mqtt_handler is None:
+        mqtt_handler = init_mqtt()
+    return mqtt_handler
+
+def cleanup_mqtt():
+    """Clean up MQTT handler resources."""
+    global mqtt_handler
+    if mqtt_handler is not None:
+        mqtt_handler.stop()
+        mqtt_handler = None 

+ 39 - 0
dune_weaver_flask/modules/mqtt/base.py

@@ -0,0 +1,39 @@
+"""Base MQTT handler interface."""
+from abc import ABC, abstractmethod
+from typing import Dict, Callable, List, Optional, Any
+
+class BaseMQTTHandler(ABC):
+    """Abstract base class for MQTT handlers."""
+    
+    @abstractmethod
+    def start(self) -> None:
+        """Start the MQTT handler."""
+        pass
+    
+    @abstractmethod
+    def stop(self) -> None:
+        """Stop the MQTT handler."""
+        pass
+    
+    @abstractmethod
+    def update_state(self, is_running: Optional[bool] = None, 
+                    current_file: Optional[str] = None,
+                    patterns: Optional[List[str]] = None, 
+                    serial: Optional[str] = None,
+                    playlist: Optional[Dict[str, Any]] = None) -> None:
+        """Update the state of the sand table and publish to MQTT.
+        
+        Args:
+            is_running: Whether the table is currently running a pattern
+            current_file: The currently playing file
+            patterns: List of available pattern files
+            serial: Serial connection status
+            playlist: Current playlist information if in playlist mode
+        """
+        pass
+    
+    @property
+    @abstractmethod
+    def is_enabled(self) -> bool:
+        """Return whether MQTT functionality is enabled."""
+        pass 

+ 25 - 0
dune_weaver_flask/modules/mqtt/factory.py

@@ -0,0 +1,25 @@
+"""Factory for creating MQTT handlers."""
+import os
+from typing import Dict, Callable
+from .base import BaseMQTTHandler
+from .handler import MQTTHandler
+from .mock import MockMQTTHandler
+from .utils import create_mqtt_callbacks
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+def create_mqtt_handler() -> BaseMQTTHandler:
+    """Create and return an appropriate MQTT handler based on configuration.
+    
+    Returns:
+        BaseMQTTHandler: Either a real MQTTHandler if MQTT_BROKER is configured,
+                        or a MockMQTTHandler if not.
+    """
+    if os.getenv('MQTT_BROKER'):
+        logger.info("Got MQTT configuration, instantiating MQTTHandler")
+        return MQTTHandler(create_mqtt_callbacks())
+    
+    logger.info("MQTT Not going to be used, instantiating MockMQTTHandler")
+    return MockMQTTHandler() 

+ 311 - 0
dune_weaver_flask/modules/mqtt/handler.py

@@ -0,0 +1,311 @@
+"""Real MQTT handler implementation."""
+import os
+import threading
+import time
+import json
+from typing import Dict, Callable, List, Optional, Any
+import paho.mqtt.client as mqtt
+import logging
+
+from .base import BaseMQTTHandler
+from dune_weaver_flask.modules.core.state import state
+from dune_weaver_flask.modules.core.pattern_manager import list_theta_rho_files
+from dune_weaver_flask.modules.core.playlist_manager import list_all_playlists
+from dune_weaver_flask.modules.serial.serial_manager import is_connected
+
+logger = logging.getLogger(__name__)
+
+class MQTTHandler(BaseMQTTHandler):
+    """Real implementation of MQTT handler."""
+    
+    def __init__(self, callback_registry: Dict[str, Callable]):
+        # MQTT Configuration from environment variables
+        self.broker = os.getenv('MQTT_BROKER')
+        self.port = int(os.getenv('MQTT_PORT', '1883'))
+        self.username = os.getenv('MQTT_USERNAME')
+        self.password = os.getenv('MQTT_PASSWORD')
+        self.client_id = os.getenv('MQTT_CLIENT_ID', 'dune_weaver')
+        self.status_topic = os.getenv('MQTT_STATUS_TOPIC', 'dune_weaver/status')
+        self.command_topic = os.getenv('MQTT_COMMAND_TOPIC', 'dune_weaver/command')
+        self.status_interval = int(os.getenv('MQTT_STATUS_INTERVAL', '30'))
+
+        # Store callback registry
+        self.callback_registry = callback_registry
+
+        # Threading control
+        self.running = False
+        self.status_thread = None
+
+        # Home Assistant MQTT Discovery settings
+        self.discovery_prefix = os.getenv('MQTT_DISCOVERY_PREFIX', 'homeassistant')
+        self.device_name = os.getenv('HA_DEVICE_NAME', 'Sand Table')
+        self.device_id = os.getenv('HA_DEVICE_ID', 'dune_weaver')
+        
+        # Additional topics for state
+        self.current_file_topic = f"{self.device_id}/state/current_file"
+        self.running_state_topic = f"{self.device_id}/state/running"
+        self.serial_state_topic = f"{self.device_id}/state/serial"
+        self.pattern_select_topic = f"{self.device_id}/pattern/set"
+        self.playlist_select_topic = f"{self.device_id}/playlist/set"
+        self.speed_topic = f"{self.device_id}/speed/set"
+
+        # Store current state
+        self.current_file = ""
+        self.is_running_state = False
+        self.serial_state = ""
+        self.patterns = []
+        self.playlists = []
+
+        # Initialize MQTT client if broker is configured
+        if self.broker:
+            self.client = mqtt.Client(client_id=self.client_id)
+            self.client.on_connect = self.on_connect
+            self.client.on_message = self.on_message
+
+            if self.username and self.password:
+                self.client.username_pw_set(self.username, self.password)
+
+    def setup_ha_discovery(self):
+        """Publish Home Assistant MQTT discovery configurations."""
+        if not self.is_enabled:
+            return
+
+        base_device = {
+            "identifiers": [self.device_id],
+            "name": self.device_name,
+            "model": "Dune Weaver",
+            "manufacturer": "DIY"
+        }
+        
+        # Serial State Sensor
+        serial_config = {
+            "name": f"{self.device_name} Serial State",
+            "unique_id": f"{self.device_id}_serial_state",
+            "state_topic": self.serial_state_topic,
+            "device": base_device,
+            "icon": "mdi:serial-port",
+            "entity_category": "diagnostic"
+        }
+        self._publish_discovery("sensor", "serial_state", serial_config)
+
+        # Running State Sensor
+        running_config = {
+            "name": f"{self.device_name} Running State",
+            "unique_id": f"{self.device_id}_running_state",
+            "state_topic": self.running_state_topic,
+            "device": base_device,
+            "icon": "mdi:machine",
+            "entity_category": "diagnostic"
+        }
+        self._publish_discovery("binary_sensor", "running_state", running_config)
+
+        # Speed Control
+        speed_config = {
+            "name": f"{self.device_name} Speed",
+            "unique_id": f"{self.device_id}_speed",
+            "command_topic": self.speed_topic,
+            "state_topic": f"{self.speed_topic}/state",
+            "device": base_device,
+            "icon": "mdi:speedometer",
+            "min": 50,
+            "max": 1000,
+            "step": 50
+        }
+        self._publish_discovery("number", "speed", speed_config)
+
+        # Pattern Select
+        pattern_config = {
+            "name": f"{self.device_name} Pattern",
+            "unique_id": f"{self.device_id}_pattern",
+            "command_topic": self.pattern_select_topic,
+            "state_topic": f"{self.pattern_select_topic}/state",
+            "options": self.patterns,
+            "device": base_device,
+            "icon": "mdi:draw"
+        }
+        self._publish_discovery("select", "pattern", pattern_config)
+
+        # Playlist Select
+        playlist_config = {
+            "name": f"{self.device_name} Playlist",
+            "unique_id": f"{self.device_id}_playlist",
+            "command_topic": self.playlist_select_topic,
+            "state_topic": f"{self.playlist_select_topic}/state",
+            "options": self.playlists,
+            "device": base_device,
+            "icon": "mdi:playlist-play"
+        }
+        self._publish_discovery("select", "playlist", playlist_config)
+
+        # Playlist Active Sensor
+        playlist_active_config = {
+            "name": f"{self.device_name} Playlist Active",
+            "unique_id": f"{self.device_id}_playlist_active",
+            "state_topic": f"{self.device_id}/state/playlist",
+            "value_template": "{{ value_json.active }}",
+            "device": base_device,
+            "icon": "mdi:playlist-play",
+            "entity_category": "diagnostic"
+        }
+        self._publish_discovery("binary_sensor", "playlist_active", playlist_active_config)
+
+    def _publish_discovery(self, component: str, config_type: str, config: dict):
+        """Helper method to publish HA discovery configs."""
+        if not self.is_enabled:
+            return
+            
+        discovery_topic = f"{self.discovery_prefix}/{component}/{self.device_id}/{config_type}/config"
+        self.client.publish(discovery_topic, json.dumps(config), retain=True)
+
+    def update_state(self, is_running: Optional[bool] = None, 
+                    current_file: Optional[str] = None,
+                    patterns: Optional[List[str]] = None, 
+                    serial: Optional[str] = None,
+                    playlist: Optional[Dict[str, Any]] = None) -> None:
+        """Update the state of the sand table and publish to MQTT."""
+        if not self.is_enabled:
+            return
+
+        if is_running is not None:
+            self.is_running_state = is_running
+            self.client.publish(self.running_state_topic, "ON" if is_running else "OFF", retain=True)
+        
+        if current_file is not None:
+            self.current_file = current_file
+            self.client.publish(self.current_file_topic, current_file, retain=True)
+            self.client.publish(f"{self.pattern_select_topic}/state", current_file, retain=True)
+
+        if patterns is not None and set(patterns) != set(self.patterns):
+            self.patterns = patterns
+            self.setup_ha_discovery()
+        
+        if serial is not None:
+            self.serial_state = serial
+            self.client.publish(self.serial_state_topic, serial, retain=True)
+
+        # Update speed state
+        self.client.publish(f"{self.speed_topic}/state", state.speed, retain=True)
+
+    def on_connect(self, client, userdata, flags, rc):
+        """Callback when connected to MQTT broker."""
+        logger.info(f"Connected to MQTT broker with result code {rc}")
+        # Subscribe to command topics
+        client.subscribe([
+            (self.command_topic, 0),
+            (self.pattern_select_topic, 0),
+            (self.playlist_select_topic, 0),
+            (self.speed_topic, 0)
+        ])
+        # Publish discovery configurations
+        self.setup_ha_discovery()
+
+    def on_message(self, client, userdata, msg):
+        """Callback when message is received."""
+        try:
+            if msg.topic == self.pattern_select_topic:
+                # Handle pattern selection
+                pattern_name = msg.payload.decode()
+                if pattern_name in self.patterns:
+                    self.callback_registry['run_pattern'](file_path=f"{pattern_name}")
+                    self.client.publish(f"{self.pattern_select_topic}/state", pattern_name, retain=True)
+            elif msg.topic == self.playlist_select_topic:
+                # Handle playlist selection
+                playlist_name = msg.payload.decode()
+                if playlist_name in self.playlists:
+                    self.callback_registry['run_playlist'](playlist_name=playlist_name)
+                    self.client.publish(f"{self.playlist_select_topic}/state", playlist_name, retain=True)
+            elif msg.topic == self.speed_topic:
+                speed = int(msg.payload.decode())
+                self.callback_registry['set_speed'](speed)
+            else:
+                # Handle other commands
+                payload = json.loads(msg.payload.decode())
+                command = payload.get('command')
+                params = payload.get('params', {})
+
+                if command in self.callback_registry:
+                    self.callback_registry[command](**params)
+                else:
+                    logger.error(f"Unknown command received: {command}")
+
+        except json.JSONDecodeError:
+            logger.error(f"Invalid JSON payload received: {msg.payload}")
+        except Exception as e:
+            logger.error(f"Error processing MQTT message: {e}")
+
+    def publish_status(self):
+        """Publish status updates periodically."""
+        while self.running:
+            try:
+                # Create status message
+                status = {
+                    "status": "running" if not state.stop_requested else "idle",
+                    "timestamp": time.time(),
+                    "client_id": self.client_id,
+                    "current_file": state.current_playing_file,
+                    "speed": state.speed,
+                    "position": {
+                        "theta": state.current_theta,
+                        "rho": state.current_rho,
+                        "x": state.machine_x,
+                        "y": state.machine_y
+                    }
+                }
+                logger.info(f"publishing status: {status}" )
+                self.client.publish(self.status_topic, json.dumps(status))
+                
+                # Wait for next interval
+                time.sleep(self.status_interval)
+            except Exception as e:
+                logger.error(f"Error publishing status: {e}")
+                time.sleep(5)  # Wait before retry
+
+    def start(self) -> None:
+        """Start the MQTT handler."""
+        if not self.is_enabled:
+            return
+        
+        try:
+            self.client.connect(self.broker, self.port)
+            self.client.loop_start()
+            # Start status publishing thread
+            self.running = True
+            self.status_thread = threading.Thread(target=self.publish_status, daemon=True)
+            self.status_thread.start()
+            
+            # Get initial states
+            patterns = list_theta_rho_files()
+            playlists = list_all_playlists()
+            serial_status = is_connected()
+            
+            logger.info(patterns, playlists, serial_status)
+            # Wait for connection
+            time.sleep(1)
+
+            # Publish initial states
+            self.update_state(
+                is_running=not state.stop_requested,
+                current_file=state.current_playing_file,
+                patterns=patterns,
+                serial=serial_status.get('status', '')
+            )
+            
+            logger.info("MQTT Handler started successfully")
+        except Exception as e:
+            logger.error(f"Failed to start MQTT Handler: {e}")
+
+    def stop(self) -> None:
+        """Stop the MQTT handler."""
+        if not self.is_enabled:
+            return
+
+        self.running = False
+        if self.status_thread:
+            self.status_thread.join(timeout=1)
+        self.client.loop_stop()
+        self.client.disconnect()
+
+    @property
+    def is_enabled(self) -> bool:
+        """Return whether MQTT functionality is enabled."""
+        return bool(self.broker) 

+ 24 - 0
dune_weaver_flask/modules/mqtt/mock.py

@@ -0,0 +1,24 @@
+"""Mock MQTT handler implementation."""
+from .base import BaseMQTTHandler
+
+
+
+class MockMQTTHandler(BaseMQTTHandler):
+    """Mock implementation of MQTT handler that does nothing."""
+    
+    def start(self) -> None:
+        """No-op start."""
+        pass
+    
+    def stop(self) -> None:
+        """No-op stop."""
+        pass
+    
+    def update_state(self, **kwargs) -> None:
+        """No-op state update."""
+        pass
+    
+    @property
+    def is_enabled(self) -> bool:
+        """Always returns False since this is a mock."""
+        return False 

+ 53 - 0
dune_weaver_flask/modules/mqtt/utils.py

@@ -0,0 +1,53 @@
+"""MQTT utilities and callback management."""
+import os
+from typing import Dict, Callable
+from dune_weaver_flask.modules.core.pattern_manager import (
+    run_theta_rho_file, stop_actions, pause_execution,
+    resume_execution, THETA_RHO_DIR,
+    run_theta_rho_files
+)
+from dune_weaver_flask.modules.core.playlist_manager import get_playlist
+from dune_weaver_flask.modules.serial.serial_manager import is_connected, home
+from dune_weaver_flask.modules.core.state import state
+
+def create_mqtt_callbacks() -> Dict[str, Callable]:
+    """Create and return the MQTT callback registry."""
+    def set_speed(speed):
+        state.speed = speed
+    return {
+        'run_pattern': run_theta_rho_file,  # Already handles file path
+        'run_playlist': lambda name: run_theta_rho_files(
+            [os.path.join(THETA_RHO_DIR, file) for file in get_playlist(name)['files']],
+            run_mode='loop',
+            pause_time=0,
+            clear_pattern=None
+        ),
+        'stop': stop_actions,  # Already handles state
+        'pause': pause_execution,  # Already handles state
+        'resume': resume_execution,  # Already handles state
+        'home': home,
+        'set_speed': set_speed
+    }
+
+def get_mqtt_state():
+    """Get the current state for MQTT updates."""
+    # Get list of pattern files
+    patterns = []
+    for root, _, filenames in os.walk(THETA_RHO_DIR):
+        for file in filenames:
+            if file.endswith('.thr'):
+                patterns.append(file)
+    
+    # Get current execution status
+    is_running = not state.stop_requested and not state.pause_requested
+    current_file = state.current_playing_file or ''
+    
+    # Get serial status
+    serial_status = is_connected()
+    
+    return {
+        'is_running': is_running,
+        'current_file': current_file,
+        'patterns': sorted(patterns),
+        'serial': serial_status.get('status', ''),
+    } 

+ 2 - 0
requirements.txt

@@ -2,3 +2,5 @@ flask
 pyserial
 esptool
 tqdm
+paho-mqtt
+python-dotenv