# main.py (完整替换) from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel, HttpUrl import logging import logging.handlers import os from contextlib import asynccontextmanager import requests import cv2 import numpy as np import io from PIL import Image from typing import List, Dict, Any, Set # --- Detector Module Imports --- # 我们现在只需要导入 detect_hardhat_with_sahi 来处理所有PPE相关的检测 from utils.hardhat_detector import initialize_helmet_model, detect_hardhat_with_sahi from utils.smoking_detector import initialize_smoking_model, detect_smoking_behavior from utils.fire_smoke_detector import initialize_fire_smoke_model, detect_fire_smoke_with_sahi from utils.pho_detector import initialize_all_pho_models, detect_objects_by_class, PHO_DETECTOR_SUPPORTED_CLASSES # --- 1. 中央日志系统配置 --- def setup_logging(): LOG_DIR = "logs" if not os.path.exists(LOG_DIR): os.makedirs(LOG_DIR) log_file_path = os.path.join(LOG_DIR, "app.log") log_formatter = logging.Formatter( "%(asctime)s - %(levelname)s - [%(name)s] - [%(module)s.%(funcName)s:%(lineno)d] - %(message)s" ) file_handler = logging.handlers.RotatingFileHandler( log_file_path, maxBytes=5*1024*1024, backupCount=5, encoding="utf-8" ) file_handler.setFormatter(log_formatter) stream_handler = logging.StreamHandler() stream_handler.setFormatter(log_formatter) root_logger = logging.getLogger() root_logger.setLevel(logging.INFO) if root_logger.hasHandlers(): root_logger.handlers.clear() root_logger.addHandler(file_handler) root_logger.addHandler(stream_handler) return logging.getLogger(__name__) logger = setup_logging() # --- 2. 应用启动/关闭事件 --- @asynccontextmanager async def lifespan(app: FastAPI): """在应用启动时初始化所有模型。""" logger.info("应用启动程序开始...") # (此处的模型初始化调用保持不变) try: initialize_helmet_model() except Exception: logger.error("初始化安全帽(PPE)模型失败", exc_info=True) try: initialize_smoking_model() except Exception: logger.error("初始化吸烟(smoking)模型失败", exc_info=True) try: initialize_fire_smoke_model() except Exception: logger.error("初始化火焰和烟雾模型失败", exc_info=True) try: initialize_all_pho_models() except Exception: logger.error("初始化光伏模型失败", exc_info=True) yield logger.info("应用已关闭。") app = FastAPI(lifespan=lifespan) # --- 3. 请求模型和辅助函数 --- class DetectionRequest(BaseModel): type: str url: HttpUrl extract: bool = False def url_to_cv_image(url: str) -> np.ndarray: try: response = requests.get(str(url), timeout=15) response.raise_for_status() image_pil = Image.open(io.BytesIO(response.content)).convert("RGB") return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) except requests.exceptions.RequestException as e: raise HTTPException(status_code=400, detail=f"无法下载或处理图片。错误: {e}") except Exception as e: raise HTTPException(status_code=500, detail=f"处理图片时发生未知错误: {e}") # --- 4. 路由和核心逻辑 --- # 新增:定义与PPE相关的任务关键字和它们的映射关系 PPE_TASK_KEYWORDS = {"nohelmet", "novest"} VIOLATION_MAP = { "nohelmet": "nohelmet", "novest": "novest", } @app.post("/detect_image") async def run_detection(request: DetectionRequest): """主检测端点,将任务分发给相应的检测器,现在会聚合PPE任务。""" logger.info(f"收到检测请求:types='{request.type}', url='{request.url}'") try: image_cv = url_to_cv_image(request.url) original_img_size = [image_cv.shape[1], image_cv.shape[0]] except HTTPException as e: logger.error(f"图像处理失败,URL: {request.url}", exc_info=True) return JSONResponse(status_code=e.status_code, content={"error": e.detail}) requested_types: Set[str] = {t.strip().lower() for t in request.type.split(',') if t.strip()} if not requested_types: return JSONResponse(status_code=400, content={"error": "未指定任何检测类型。"}) logger.info(f"已解析的检测类型: {requested_types}") aggregated_targets: List[Dict] = [] error_messages: List[str] = [] # --- 新的智能分发逻辑 --- # 1. 识别并聚合所有PPE相关的任务 ppe_tasks_requested = requested_types.intersection(PPE_TASK_KEYWORDS) if ppe_tasks_requested: # 将用户请求的任务关键字 (hardhat, vest) 转换为后端需要的违规代码 (nohelmet, novest) violations_to_report = {VIOLATION_MAP[task] for task in ppe_tasks_requested} logger.info(f"分发PPE任务到检测器: {ppe_tasks_requested}, 需要报告的违规: {violations_to_report}") try: # 将需要报告的违规类型集合传递给检测函数 results = detect_hardhat_with_sahi(image_cv, requested_violations=violations_to_report, extract=request.extract) aggregated_targets.extend(results) except Exception as e: logger.error(f"处理PPE任务 {ppe_tasks_requested} 时发生严重错误", exc_info=True) error_messages.append(f"处理PPE任务时出错: {e}") # 2. 识别并分发光伏板相关任务 pho_tasks_to_run = requested_types.intersection(PHO_DETECTOR_SUPPORTED_CLASSES) if pho_tasks_to_run: logger.info(f"分发光伏板相关任务到检测器: {pho_tasks_to_run}") try: # 确保传递一个列表给detect_objects_by_class results = detect_objects_by_class(image_cv, list(pho_tasks_to_run)) aggregated_targets.extend(results) except Exception as e: logger.error(f"处理光伏板任务 {pho_tasks_to_run} 时发生严重错误", exc_info=True) error_messages.append(f"处理光伏板任务时出错: {e}") # 3. 处理其他完全独立的任务 remaining_tasks = requested_types - ppe_tasks_requested - pho_tasks_to_run if remaining_tasks: logger.info(f"正在处理其他独立任务: {remaining_tasks}") for task in remaining_tasks: try: detector_results = None if task == "smoking": detector_results = detect_smoking_behavior(image_cv, extract=request.extract) elif task in ["fire", "smoke"]: detector_results = detect_fire_smoke_with_sahi(image_cv, extract=request.extract) else: logger.warning(f"检测到未知任务类型,将被忽略: {task}") error_messages.append(f"未知的检测类型: '{task}'") continue if isinstance(detector_results, list): aggregated_targets.extend(detector_results) except Exception as e: logger.error(f"调用检测器处理任务 '{task}' 时失败", exc_info=True) error_messages.append(f"处理 '{task}' 时出错: {e}") # --- 最终响应构建 --- response_content = { "hasTarget": 1 if aggregated_targets else 0, "originalImgSize": original_img_size, "targets": aggregated_targets, } if error_messages: response_content["processing_errors"] = error_messages logger.warning(f"请求处理完成,但包含错误信息: {error_messages}") return JSONResponse(status_code=200, content=response_content) @app.get("/") def read_root(): return {"message": "高级检测API正在运行。"}