Files
AI-TianDong/main.py

182 lines
7.8 KiB
Python
Raw Normal View History

2025-07-24 12:45:27 +08:00
# main.py (完整替换)
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, HttpUrl
import logging
import logging.handlers
import os
from contextlib import asynccontextmanager
import requests
import cv2
import numpy as np
import io
from PIL import Image
from typing import List, Dict, Any, Set
# --- Detector Module Imports ---
# 我们现在只需要导入 detect_hardhat_with_sahi 来处理所有PPE相关的检测
from utils.hardhat_detector import initialize_helmet_model, detect_hardhat_with_sahi
from utils.smoking_detector import initialize_smoking_model, detect_smoking_behavior
from utils.fire_smoke_detector import initialize_fire_smoke_model, detect_fire_smoke_with_sahi
from utils.pho_detector import initialize_all_pho_models, detect_objects_by_class, PHO_DETECTOR_SUPPORTED_CLASSES
# --- 1. 中央日志系统配置 ---
def setup_logging():
LOG_DIR = "logs"
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)
log_file_path = os.path.join(LOG_DIR, "app.log")
log_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [%(name)s] - [%(module)s.%(funcName)s:%(lineno)d] - %(message)s"
)
file_handler = logging.handlers.RotatingFileHandler(
log_file_path, maxBytes=5*1024*1024, backupCount=5, encoding="utf-8"
)
file_handler.setFormatter(log_formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(log_formatter)
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
if root_logger.hasHandlers():
root_logger.handlers.clear()
root_logger.addHandler(file_handler)
root_logger.addHandler(stream_handler)
return logging.getLogger(__name__)
logger = setup_logging()
# --- 2. 应用启动/关闭事件 ---
@asynccontextmanager
async def lifespan(app: FastAPI):
"""在应用启动时初始化所有模型。"""
logger.info("应用启动程序开始...")
# (此处的模型初始化调用保持不变)
try: initialize_helmet_model()
except Exception: logger.error("初始化安全帽PPE模型失败", exc_info=True)
try: initialize_smoking_model()
except Exception: logger.error("初始化吸烟smoking模型失败", exc_info=True)
try: initialize_fire_smoke_model()
except Exception: logger.error("初始化火焰和烟雾模型失败", exc_info=True)
try: initialize_all_pho_models()
except Exception: logger.error("初始化光伏模型失败", exc_info=True)
yield
logger.info("应用已关闭。")
app = FastAPI(lifespan=lifespan)
# --- 3. 请求模型和辅助函数 ---
class DetectionRequest(BaseModel):
type: str
url: HttpUrl
extract: bool = False
def url_to_cv_image(url: str) -> np.ndarray:
try:
response = requests.get(str(url), timeout=15)
response.raise_for_status()
image_pil = Image.open(io.BytesIO(response.content)).convert("RGB")
return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
except requests.exceptions.RequestException as e:
raise HTTPException(status_code=400, detail=f"无法下载或处理图片。错误: {e}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"处理图片时发生未知错误: {e}")
# --- 4. 路由和核心逻辑 ---
# 新增定义与PPE相关的任务关键字和它们的映射关系
PPE_TASK_KEYWORDS = {"nohelmet", "novest"}
VIOLATION_MAP = {
"nohelmet": "nohelmet",
"novest": "novest",
}
@app.post("/detect_image")
async def run_detection(request: DetectionRequest):
"""主检测端点将任务分发给相应的检测器现在会聚合PPE任务。"""
logger.info(f"收到检测请求types='{request.type}', url='{request.url}'")
try:
image_cv = url_to_cv_image(request.url)
original_img_size = [image_cv.shape[1], image_cv.shape[0]]
except HTTPException as e:
logger.error(f"图像处理失败URL: {request.url}", exc_info=True)
return JSONResponse(status_code=e.status_code, content={"error": e.detail})
requested_types: Set[str] = {t.strip().lower() for t in request.type.split(',') if t.strip()}
if not requested_types:
return JSONResponse(status_code=400, content={"error": "未指定任何检测类型。"})
logger.info(f"已解析的检测类型: {requested_types}")
aggregated_targets: List[Dict] = []
error_messages: List[str] = []
# --- 新的智能分发逻辑 ---
# 1. 识别并聚合所有PPE相关的任务
ppe_tasks_requested = requested_types.intersection(PPE_TASK_KEYWORDS)
if ppe_tasks_requested:
# 将用户请求的任务关键字 (hardhat, vest) 转换为后端需要的违规代码 (nohelmet, novest)
violations_to_report = {VIOLATION_MAP[task] for task in ppe_tasks_requested}
logger.info(f"分发PPE任务到检测器: {ppe_tasks_requested}, 需要报告的违规: {violations_to_report}")
try:
# 将需要报告的违规类型集合传递给检测函数
results = detect_hardhat_with_sahi(image_cv, requested_violations=violations_to_report, extract=request.extract)
aggregated_targets.extend(results)
except Exception as e:
logger.error(f"处理PPE任务 {ppe_tasks_requested} 时发生严重错误", exc_info=True)
error_messages.append(f"处理PPE任务时出错: {e}")
# 2. 识别并分发光伏板相关任务
pho_tasks_to_run = requested_types.intersection(PHO_DETECTOR_SUPPORTED_CLASSES)
if pho_tasks_to_run:
logger.info(f"分发光伏板相关任务到检测器: {pho_tasks_to_run}")
try:
# 确保传递一个列表给detect_objects_by_class
results = detect_objects_by_class(image_cv, list(pho_tasks_to_run))
aggregated_targets.extend(results)
except Exception as e:
logger.error(f"处理光伏板任务 {pho_tasks_to_run} 时发生严重错误", exc_info=True)
error_messages.append(f"处理光伏板任务时出错: {e}")
# 3. 处理其他完全独立的任务
remaining_tasks = requested_types - ppe_tasks_requested - pho_tasks_to_run
if remaining_tasks:
logger.info(f"正在处理其他独立任务: {remaining_tasks}")
for task in remaining_tasks:
try:
detector_results = None
if task == "smoking":
detector_results = detect_smoking_behavior(image_cv, extract=request.extract)
elif task in ["fire", "smoke"]:
detector_results = detect_fire_smoke_with_sahi(image_cv, extract=request.extract)
else:
logger.warning(f"检测到未知任务类型,将被忽略: {task}")
error_messages.append(f"未知的检测类型: '{task}'")
continue
if isinstance(detector_results, list):
aggregated_targets.extend(detector_results)
except Exception as e:
logger.error(f"调用检测器处理任务 '{task}' 时失败", exc_info=True)
error_messages.append(f"处理 '{task}' 时出错: {e}")
# --- 最终响应构建 ---
response_content = {
"hasTarget": 1 if aggregated_targets else 0,
"originalImgSize": original_img_size,
"targets": aggregated_targets,
}
if error_messages:
response_content["processing_errors"] = error_messages
logger.warning(f"请求处理完成,但包含错误信息: {error_messages}")
return JSONResponse(status_code=200, content=response_content)
@app.get("/")
def read_root():
return {"message": "高级检测API正在运行。"}