Sfoglia il codice sorgente

all endpoints should run async

tuanchris 4 mesi fa
parent
commit
41086c305e
1 ha cambiato i file con 48 aggiunte e 25 eliminazioni
  1. 48 25
      main.py

+ 48 - 25
main.py

@@ -333,7 +333,8 @@ async def restart(request: ConnectRequest):
 @app.get("/list_theta_rho_files")
 async def list_theta_rho_files():
     logger.debug("Listing theta-rho files")
-    files = pattern_manager.list_theta_rho_files()
+    # Run the blocking file system operation in a thread pool
+    files = await asyncio.to_thread(pattern_manager.list_theta_rho_files)
     return sorted(files)
 
 @app.get("/list_theta_rho_files_with_metadata")
@@ -346,9 +347,10 @@ async def list_theta_rho_files_with_metadata():
     import asyncio
     from concurrent.futures import ThreadPoolExecutor
     
-    files = pattern_manager.list_theta_rho_files()
+    # Run the blocking file listing in a thread
+    files = await asyncio.to_thread(pattern_manager.list_theta_rho_files)
     files_with_metadata = []
-    
+
     # Use ThreadPoolExecutor for I/O-bound operations
     executor = ThreadPoolExecutor(max_workers=4)
     
@@ -473,11 +475,13 @@ async def get_theta_rho_coordinates(request: GetCoordinatesRequest):
         file_name = normalize_file_path(request.file_name)
         file_path = os.path.join(THETA_RHO_DIR, file_name)
         
-        if not os.path.exists(file_path):
+        # Check file existence asynchronously
+        exists = await asyncio.to_thread(os.path.exists, file_path)
+        if not exists:
             raise HTTPException(status_code=404, detail=f"File {file_name} not found")
-        
-        # Parse the theta-rho file
-        coordinates = parse_theta_rho_file(file_path)
+
+        # Parse the theta-rho file asynchronously
+        coordinates = await asyncio.to_thread(parse_theta_rho_file, file_path)
         
         if not coordinates:
             raise HTTPException(status_code=400, detail="No valid coordinates found in file")
@@ -595,18 +599,21 @@ async def delete_theta_rho_file(request: DeleteFileRequest):
     # Normalize file path for cross-platform compatibility
     normalized_file_name = normalize_file_path(request.file_name)
     file_path = os.path.join(pattern_manager.THETA_RHO_DIR, normalized_file_name)
-    if not os.path.exists(file_path):
+
+    # Check file existence asynchronously
+    exists = await asyncio.to_thread(os.path.exists, file_path)
+    if not exists:
         logger.error(f"Attempted to delete non-existent file: {file_path}")
         raise HTTPException(status_code=404, detail="File not found")
 
     try:
-        # Delete the pattern file
-        os.remove(file_path)
+        # Delete the pattern file asynchronously
+        await asyncio.to_thread(os.remove, file_path)
         logger.info(f"Successfully deleted theta-rho file: {request.file_name}")
         
-        # Clean up cached preview image and metadata
+        # Clean up cached preview image and metadata asynchronously
         from modules.core.cache_manager import delete_pattern_cache
-        cache_cleanup_success = delete_pattern_cache(normalized_file_name)
+        cache_cleanup_success = await asyncio.to_thread(delete_pattern_cache, normalized_file_name)
         if cache_cleanup_success:
             logger.info(f"Successfully cleaned up cache for {request.file_name}")
         else:
@@ -655,18 +662,24 @@ async def preview_thr(request: DeleteFileRequest):
     normalized_file_name = normalize_file_path(request.file_name)
     # Construct the full path to the pattern file to check existence
     pattern_file_path = os.path.join(pattern_manager.THETA_RHO_DIR, normalized_file_name)
-    if not os.path.exists(pattern_file_path):
+
+    # Check file existence asynchronously
+    exists = await asyncio.to_thread(os.path.exists, pattern_file_path)
+    if not exists:
         logger.error(f"Attempted to preview non-existent pattern file: {pattern_file_path}")
         raise HTTPException(status_code=404, detail="Pattern file not found")
 
     try:
         cache_path = get_cache_path(normalized_file_name)
-        
-        if not os.path.exists(cache_path):
+
+        # Check cache existence asynchronously
+        cache_exists = await asyncio.to_thread(os.path.exists, cache_path)
+        if not cache_exists:
             logger.info(f"Cache miss for {request.file_name}. Generating preview...")
             # Attempt to generate the preview if it's missing
             success = await generate_image_preview(normalized_file_name)
-            if not success or not os.path.exists(cache_path):
+            cache_exists_after = await asyncio.to_thread(os.path.exists, cache_path)
+            if not success or not cache_exists_after:
                 logger.error(f"Failed to generate or find preview for {request.file_name} after attempting generation.")
                 raise HTTPException(status_code=500, detail="Failed to generate preview image.")
 
@@ -1044,14 +1057,17 @@ async def preview_thr_batch(request: dict):
         "Content-Type": "application/json"
     }
 
-    results = {}
-    for file_name in file_names:
+    async def process_single_file(file_name):
+        """Process a single file and return its preview data."""
         t1 = time.time()
         try:
             # Normalize file path for cross-platform compatibility
             normalized_file_name = normalize_file_path(file_name)
             pattern_file_path = os.path.join(pattern_manager.THETA_RHO_DIR, normalized_file_name)
-            if not os.path.exists(pattern_file_path):
+
+            # Check file existence asynchronously
+            exists = await asyncio.to_thread(os.path.exists, pattern_file_path)
+            if not exists:
                 logger.warning(f"Pattern file not found: {pattern_file_path}")
                 results[file_name] = {"error": "Pattern file not found"}
                 continue
@@ -1078,19 +1094,26 @@ async def preview_thr_batch(request: dict):
                 first_coord_obj = {"x": first_coord[0], "y": first_coord[1]} if first_coord else None
                 last_coord_obj = {"x": last_coord[0], "y": last_coord[1]} if last_coord else None
 
-            with open(cache_path, 'rb') as f:
-                image_data = f.read()
+            # Read image file asynchronously
+            image_data = await asyncio.to_thread(lambda: open(cache_path, 'rb').read())
             image_b64 = base64.b64encode(image_data).decode('utf-8')
-            results[file_name] = {
+            result = {
                 "image_data": f"data:image/webp;base64,{image_b64}",
                 "first_coordinate": first_coord_obj,
                 "last_coordinate": last_coord_obj
             }
+            logger.debug(f"Processed {file_name} in {time.time() - t1:.2f}s")
+            return file_name, result
         except Exception as e:
             logger.error(f"Error processing {file_name}: {str(e)}")
-            results[file_name] = {"error": str(e)}
-        finally:
-            logger.debug(f"Processed {file_name} in {time.time() - t1:.2f}s")
+            return file_name, {"error": str(e)}
+
+    # Process all files concurrently
+    tasks = [process_single_file(file_name) for file_name in file_names]
+    file_results = await asyncio.gather(*tasks)
+
+    # Convert results to dictionary
+    results = dict(file_results)
 
     logger.info(f"Total batch processing time: {time.time() - start:.2f}s for {len(file_names)} files")
     return JSONResponse(content=results, headers=headers)