Przeglądaj źródła

better handle shutdown

Tuan Nguyen 11 miesięcy temu
rodzic
commit
01202908c5

+ 40 - 14
app.py

@@ -51,7 +51,7 @@ async def lifespan(app: FastAPI):
     yield  # This separates startup from shutdown code
 
     # Shutdown
-    on_exit()
+    await on_exit()
 
 app = FastAPI(lifespan=lifespan)
 templates = Jinja2Templates(directory="templates")
@@ -237,11 +237,16 @@ async def run_theta_rho(request: ThetaRhoRequest, background_tasks: BackgroundTa
         files_to_run = [file_path]
         logger.info(f'Running theta-rho file: {request.file_name} with pre_execution={request.pre_execution}')
         
-        # Pass arguments in the correct order
+        # Only include clear_pattern if it's not "none"
+        kwargs = {}
+        if request.pre_execution != "none":
+            kwargs['clear_pattern'] = request.pre_execution
+        
+        # Pass arguments properly
         background_tasks.add_task(
             pattern_manager.run_theta_rho_files,
             files_to_run,  # First positional argument
-            clear_pattern=request.pre_execution if request.pre_execution != "none" else None  # Named argument
+            **kwargs  # Spread keyword arguments
         )
         return {"success": True}
     except Exception as e:
@@ -524,27 +529,48 @@ async def get_wled_ip():
         raise HTTPException(status_code=404, detail="No WLED IP set")
     return {"success": True, "wled_ip": state.wled_ip}
 
-def on_exit():
+async def on_exit():
     """Function to execute on application shutdown."""
     logger.info("Shutting down gracefully, please wait for execution to complete")
+    
+    # Stop any running patterns and save state
     pattern_manager.stop_actions()
     state.save()
+    
+    # Clean up pattern manager resources
+    await pattern_manager.cleanup_pattern_manager()
+    
+    # Clean up MQTT resources
     mqtt.cleanup_mqtt()
+    
+    # Clean up state resources
+    state.cleanup()
+    
+    logger.info("Cleanup completed")
 
 def signal_handler(signum, frame):
     """Handle shutdown signals gracefully but forcefully."""
     logger.info("Received shutdown signal, cleaning up...")
     try:
-        # Set a short timeout for cleanup operations
-        import threading
-        cleanup_thread = threading.Thread(target=lambda: [
-            pattern_manager.stop_actions(),
-            state.save(),
-            mqtt.cleanup_mqtt()
-        ])
-        cleanup_thread.daemon = True  # Make thread daemonic so it won't block exit
-        cleanup_thread.start()
-        cleanup_thread.join(timeout=10.0)  # Wait up to 2 seconds for cleanup
+        # Run cleanup operations synchronously to ensure completion
+        pattern_manager.stop_actions()
+        state.save()
+        
+        # Create an event loop for async cleanup
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        
+        # Run async cleanup operations
+        loop.run_until_complete(pattern_manager.cleanup_pattern_manager())
+        
+        # Clean up MQTT and state
+        mqtt.cleanup_mqtt()
+        state.cleanup()
+        
+        # Close the event loop
+        loop.close()
+        
+        logger.info("Cleanup completed")
     except Exception as e:
         logger.error(f"Error during cleanup: {str(e)}")
     finally:

+ 6 - 0
modules/connection/connection_manager.py

@@ -76,6 +76,9 @@ class SerialConnection(BaseConnection):
         with self.lock:
             if self.ser.is_open:
                 self.ser.close()
+        # Release the lock resources
+        self.lock._release_save()
+        self.lock = None
 
 ###############################################################################
 # WebSocket Connection Implementation
@@ -122,6 +125,9 @@ class WebSocketConnection(BaseConnection):
         with self.lock:
             if self.ws:
                 self.ws.close()
+        # Release the lock resources
+        self.lock._release_save()
+        self.lock = None
                 
 def list_serial_ports():
     """Return a list of available serial ports."""

+ 149 - 33
modules/core/pattern_manager.py

@@ -33,6 +33,78 @@ pattern_lock = asyncio.Lock()
 # Progress update task
 progress_update_task = None
 
+async def cleanup_pattern_manager():
+    """Release pattern-manager resources at shutdown.
+
+    Cancels the progress broadcaster, frees the pattern lock and pause event,
+    wakes threads blocked on state.pause_condition and resets playback state.
+    Unexpected errors are logged and re-raised.
+    """
+    global progress_update_task, pattern_lock, pause_event
+
+    try:
+        # Cancel progress update task if running
+        if progress_update_task and not progress_update_task.done():
+            progress_update_task.cancel()
+            try:
+                # Await the task so its cancellation is actually processed
+                await asyncio.shield(progress_update_task)
+            except asyncio.CancelledError:
+                pass
+            except Exception as e:
+                logger.error(f"Error cancelling progress update task: {e}")
+            finally:
+                progress_update_task = None
+
+        # Clean up pattern lock
+        if pattern_lock is not None:
+            try:
+                if pattern_lock.locked():
+                    # Use the public API rather than poking _locked/_waiters
+                    pattern_lock.release()
+            except Exception as e:
+                logger.error(f"Error cleaning up pattern lock: {e}")
+            pattern_lock = None
+
+        # Clean up pause event
+        if pause_event is not None:
+            try:
+                # Event.set() already wakes every waiter; the waiters are
+                # futures and have no set() method of their own.
+                pause_event.set()
+            except Exception as e:
+                logger.error(f"Error cleaning up pause event: {e}")
+            pause_event = None
+
+        # Clean up pause condition from state
+        if state.pause_condition:
+            try:
+                with state.pause_condition:
+                    # Wake up all waiting threads
+                    state.pause_condition.notify_all()
+                # Create a new condition to ensure clean state
+                state.pause_condition = threading.Condition()
+            except Exception as e:
+                logger.error(f"Error cleaning up pause condition: {e}")
+
+        # Clear all state variables
+        state.current_playing_file = None
+        state.execution_progress = None
+        state.current_playlist = None
+        state.current_playlist_index = None
+        state.playlist_mode = None
+        state.pause_requested = False
+        state.stop_requested = True
+        state.is_clearing = False
+
+        # Reset machine state
+        connection_manager.update_machine_position()
+
+        logger.info("Pattern manager resources cleaned up")
+    except Exception as e:
+        logger.error(f"Error during pattern manager cleanup: {e}")
+        raise
+
 def list_theta_rho_files():
     files = []
     for root, _, filenames in os.walk(THETA_RHO_DIR):
@@ -100,24 +172,26 @@ def get_clear_pattern_file(clear_pattern_mode, path=None):
             return random.choice(list(CLEAR_PATTERNS.values()))
         return CLEAR_PATTERNS[clear_pattern_mode]
 
-async def run_theta_rho_file(file_path):
+async def run_theta_rho_file(file_path, is_playlist=False):
     """Run a theta-rho file by sending data in optimized batches with tqdm ETA tracking."""
     if pattern_lock.locked():
         logger.warning("Another pattern is already running. Cannot start a new one.")
         return
 
     async with pattern_lock:  # This ensures only one pattern can run at a time
-        # Start progress update task
+        # Start progress update task only if not part of a playlist
         global progress_update_task
-        progress_update_task = asyncio.create_task(broadcast_progress())
+        if not is_playlist and not progress_update_task:
+            progress_update_task = asyncio.create_task(broadcast_progress())
         
         coordinates = parse_theta_rho_file(file_path)
         total_coordinates = len(coordinates)
 
         if total_coordinates < 2:
             logger.warning("Not enough coordinates for interpolation")
-            state.current_playing_file = None
-            state.execution_progress = None
+            if not is_playlist:
+                state.current_playing_file = None
+                state.execution_progress = None
             return
 
         state.execution_progress = (0, total_coordinates, None, 0)
@@ -171,17 +245,23 @@ async def run_theta_rho_file(file_path):
         await asyncio.sleep(0.1)
         
         connection_manager.check_idle()
-        state.current_playing_file = None
-        state.execution_progress = None
-        logger.info("Pattern execution completed")
         
-        # Cancel progress update task
-        if progress_update_task:
+        # Only clear state if not part of a playlist
+        if not is_playlist:
+            state.current_playing_file = None
+            state.execution_progress = None
+            logger.info("Pattern execution completed and state cleared")
+        else:
+            logger.info("Pattern execution completed, maintaining state for playlist")
+        
+        # Only cancel progress update task if not part of a playlist
+        if not is_playlist and progress_update_task:
             progress_update_task.cancel()
             try:
                 await progress_update_task
             except asyncio.CancelledError:
                 pass
+            progress_update_task = None
 
 async def run_theta_rho_files(file_paths, pause_time=0, clear_pattern=None, run_mode="single", shuffle=False):
     """Run multiple .thr files in sequence with options."""
@@ -190,6 +270,12 @@ async def run_theta_rho_files(file_paths, pause_time=0, clear_pattern=None, run_
     # Set initial playlist state
     state.playlist_mode = run_mode
     state.current_playlist_index = 0
+    state.current_playlist = file_paths
+    
+    # Start progress update task for the playlist
+    global progress_update_task
+    if not progress_update_task:
+        progress_update_task = asyncio.create_task(broadcast_progress())
     
     if shuffle:
         random.shuffle(file_paths)
@@ -205,18 +291,21 @@ async def run_theta_rho_files(file_paths, pause_time=0, clear_pattern=None, run_
                     logger.info("Execution stopped before starting next pattern")
                     return
 
-                if clear_pattern:
+                if clear_pattern and clear_pattern != 'none':
                     if state.stop_requested:
                         logger.info("Execution stopped before running the next clear pattern")
                         return
 
                     clear_file_path = get_clear_pattern_file(clear_pattern, path)
-                    logger.info(f"Running clear pattern: {clear_file_path}")
-                    await run_theta_rho_file(clear_file_path)
+                    if clear_file_path:  # Only run clear pattern if we got a valid file path
+                        logger.info(f"Running clear pattern: {clear_file_path}")
+                        await run_theta_rho_file(clear_file_path, is_playlist=True)
+                    else:
+                        logger.info("Skipping clear pattern - no valid clear pattern file")
 
                 if not state.stop_requested:
                     logger.info(f"Running pattern {idx + 1} of {len(file_paths)}: {path}")
-                    await run_theta_rho_file(path)
+                    await run_theta_rho_file(path, is_playlist=True)
 
                 if idx < len(file_paths) - 1:
                     if state.stop_requested:
@@ -239,27 +328,49 @@ async def run_theta_rho_files(file_paths, pause_time=0, clear_pattern=None, run_
                 logger.info("Playlist completed")
                 break
     finally:
+        # Clean up progress update task at the end of the playlist
+        if progress_update_task:
+            progress_update_task.cancel()
+            try:
+                await progress_update_task
+            except asyncio.CancelledError:
+                pass
+            progress_update_task = None
+            
+        # Clear all state variables
+        state.current_playing_file = None
+        state.execution_progress = None
         state.current_playlist = None
         state.current_playlist_index = None
         state.playlist_mode = None
-        logger.info("All requested patterns completed (or stopped)")
+        logger.info("All requested patterns completed (or stopped) and state cleared")
 
 def stop_actions(clear_playlist = True):
     """Stop all current actions."""
-    with state.pause_condition:
-        state.pause_requested = False
-        state.stop_requested = True
-        state.current_playing_file = None
-        state.execution_progress = None
-        state.is_clearing = False
-        
-        if clear_playlist:
-            # Clear playlist state
-            state.current_playlist = None
-            state.current_playlist_index = None
-            state.playlist_mode = None
+    try:
+        with state.pause_condition:
+            state.pause_requested = False
+            state.stop_requested = True
+            state.current_playing_file = None
+            state.execution_progress = None
+            state.is_clearing = False
             
-        state.pause_condition.notify_all()
+            if clear_playlist:
+                # Clear playlist state
+                state.current_playlist = None
+                state.current_playlist_index = None
+                state.playlist_mode = None
+                
+                # Cancel progress update task if we're clearing the playlist
+                global progress_update_task
+                if progress_update_task and not progress_update_task.done():
+                    progress_update_task.cancel()
+                
+            state.pause_condition.notify_all()
+            connection_manager.update_machine_position()
+    except Exception as e:
+        logger.error(f"Error during stop_actions: {e}")
+        # Ensure we still update machine position even if there's an error
         connection_manager.update_machine_position()
 
 def move_polar(theta, rho):
@@ -379,10 +490,7 @@ async def broadcast_progress():
     """Background task to broadcast progress updates."""
     from app import active_status_connections
     while True:
-        if not pattern_lock.locked():
-            # No pattern running, stop the task
-            break
-            
+        # Send status updates regardless of pattern_lock state
         status = get_status()
         disconnected = set()
         
@@ -392,12 +500,20 @@ async def broadcast_progress():
         for websocket in active_connections:
             try:
                 await websocket.send_json(status)
-            except Exception:
+            except Exception as e:
+                logger.warning(f"Failed to send status update: {e}")
                 disconnected.add(websocket)
         
         # Clean up disconnected clients
         if disconnected:
             active_status_connections.difference_update(disconnected)
+            
+        # Check if we should stop broadcasting
+        if not state.current_playlist:
+            # If no playlist, only stop if no pattern is being executed
+            if not pattern_lock.locked():
+                logger.info("No playlist or pattern running, stopping broadcast")
+                break
         
         # Wait before next update
         await asyncio.sleep(1)

+ 34 - 0
modules/core/state.py

@@ -2,6 +2,9 @@
 import threading
 import json
 import os
+import logging
+
+logger = logging.getLogger(__name__)
 
 class AppState:
     def __init__(self):
@@ -211,5 +214,36 @@ class AppState:
         self.__init__()  # Reinitialize the state
         self.save()
 
+    def cleanup(self):
+        """Release AppState resources at shutdown; unexpected errors are logged and re-raised."""
+        try:
+            # Wake every thread blocked on the pause condition, then drop it.
+            if self.pause_condition is not None:
+                try:
+                    with self.pause_condition:
+                        self.pause_condition.notify_all()
+                    # The `with` block has already released the lock; calling the
+                    # private _release_save() here would raise on an un-held lock,
+                    # so nothing further needs to be released.
+                except Exception as e:
+                    logger.error(f"Error cleaning up pause condition: {e}")
+                finally:
+                    self.pause_condition = None
+
+            # Close the connection (best effort) and forget the reference.
+            if self.conn:
+                try:
+                    self.conn.close()
+                except Exception as e:
+                    logger.error(f"Error closing connection: {e}")
+                finally:
+                    self.conn = None
+
+            self.mqtt_handler = None
+            logger.info("AppState resources cleaned up")
+        except Exception as e:
+            logger.error(f"Error during AppState cleanup: {e}")
+            raise
+
 # Create a singleton instance that you can import elsewhere:
 state = AppState()

+ 25 - 4
modules/mqtt/handler.py

@@ -487,11 +487,32 @@ class MQTTHandler(BaseMQTTHandler):
         if not self.is_enabled:
             return
 
+        # First stop the running flag to prevent new iterations
         self.running = False
-        if self.status_thread:
-            self.status_thread.join(timeout=1)
-        self.client.loop_stop()
-        self.client.disconnect()
+        
+        # Clean up status thread
+        local_status_thread = self.status_thread  # Keep a local reference
+        if local_status_thread and local_status_thread.is_alive():
+            try:
+                local_status_thread.join(timeout=5)
+                if local_status_thread.is_alive():
+                    logger.warning("MQTT status thread did not terminate cleanly")
+            except Exception as e:
+                logger.error(f"Error joining status thread: {e}")
+        self.status_thread = None
+            
+        # Clean up MQTT client
+        try:
+            if hasattr(self, 'client'):
+                self.client.loop_stop()
+                self.client.disconnect()
+        except Exception as e:
+            logger.error(f"Error disconnecting MQTT client: {e}")
+        
+        # Clean up main loop reference
+        self.main_loop = None
+        
+        logger.info("MQTT handler stopped")
 
     @property
     def is_enabled(self) -> bool: