182 lines
7.8 KiB
Python
182 lines
7.8 KiB
Python
# main.py (完整替换)
|
||
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
from pydantic import BaseModel, HttpUrl
|
||
import logging
|
||
import logging.handlers
|
||
import os
|
||
from contextlib import asynccontextmanager
|
||
import requests
|
||
import cv2
|
||
import numpy as np
|
||
import io
|
||
from PIL import Image
|
||
from typing import List, Dict, Any, Set
|
||
|
||
# --- Detector Module Imports ---
|
||
# 我们现在只需要导入 detect_hardhat_with_sahi 来处理所有PPE相关的检测
|
||
from utils.hardhat_detector import initialize_helmet_model, detect_hardhat_with_sahi
|
||
from utils.smoking_detector import initialize_smoking_model, detect_smoking_behavior
|
||
from utils.fire_smoke_detector import initialize_fire_smoke_model, detect_fire_smoke_with_sahi
|
||
from utils.pho_detector import initialize_all_pho_models, detect_objects_by_class, PHO_DETECTOR_SUPPORTED_CLASSES
|
||
|
||
|
||
# --- 1. 中央日志系统配置 ---
|
||
def setup_logging():
|
||
LOG_DIR = "logs"
|
||
if not os.path.exists(LOG_DIR):
|
||
os.makedirs(LOG_DIR)
|
||
log_file_path = os.path.join(LOG_DIR, "app.log")
|
||
log_formatter = logging.Formatter(
|
||
"%(asctime)s - %(levelname)s - [%(name)s] - [%(module)s.%(funcName)s:%(lineno)d] - %(message)s"
|
||
)
|
||
file_handler = logging.handlers.RotatingFileHandler(
|
||
log_file_path, maxBytes=5*1024*1024, backupCount=5, encoding="utf-8"
|
||
)
|
||
file_handler.setFormatter(log_formatter)
|
||
stream_handler = logging.StreamHandler()
|
||
stream_handler.setFormatter(log_formatter)
|
||
root_logger = logging.getLogger()
|
||
root_logger.setLevel(logging.INFO)
|
||
if root_logger.hasHandlers():
|
||
root_logger.handlers.clear()
|
||
root_logger.addHandler(file_handler)
|
||
root_logger.addHandler(stream_handler)
|
||
return logging.getLogger(__name__)
|
||
|
||
logger = setup_logging()
|
||
|
||
# --- 2. 应用启动/关闭事件 ---
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
"""在应用启动时初始化所有模型。"""
|
||
logger.info("应用启动程序开始...")
|
||
# (此处的模型初始化调用保持不变)
|
||
try: initialize_helmet_model()
|
||
except Exception: logger.error("初始化安全帽(PPE)模型失败", exc_info=True)
|
||
try: initialize_smoking_model()
|
||
except Exception: logger.error("初始化吸烟(smoking)模型失败", exc_info=True)
|
||
try: initialize_fire_smoke_model()
|
||
except Exception: logger.error("初始化火焰和烟雾模型失败", exc_info=True)
|
||
try: initialize_all_pho_models()
|
||
except Exception: logger.error("初始化光伏模型失败", exc_info=True)
|
||
yield
|
||
logger.info("应用已关闭。")
|
||
|
||
app = FastAPI(lifespan=lifespan)
|
||
|
||
|
||
# --- 3. 请求模型和辅助函数 ---
|
||
class DetectionRequest(BaseModel):
|
||
type: str
|
||
url: HttpUrl
|
||
extract: bool = False
|
||
|
||
def url_to_cv_image(url: str) -> np.ndarray:
|
||
try:
|
||
response = requests.get(str(url), timeout=15)
|
||
response.raise_for_status()
|
||
image_pil = Image.open(io.BytesIO(response.content)).convert("RGB")
|
||
return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
||
except requests.exceptions.RequestException as e:
|
||
raise HTTPException(status_code=400, detail=f"无法下载或处理图片。错误: {e}")
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"处理图片时发生未知错误: {e}")
|
||
|
||
|
||
# --- 4. 路由和核心逻辑 ---
|
||
|
||
# 新增:定义与PPE相关的任务关键字和它们的映射关系
|
||
PPE_TASK_KEYWORDS = {"nohelmet", "novest"}
|
||
VIOLATION_MAP = {
|
||
"nohelmet": "nohelmet",
|
||
"novest": "novest",
|
||
}
|
||
|
||
@app.post("/detect_image")
|
||
async def run_detection(request: DetectionRequest):
|
||
"""主检测端点,将任务分发给相应的检测器,现在会聚合PPE任务。"""
|
||
logger.info(f"收到检测请求:types='{request.type}', url='{request.url}'")
|
||
|
||
try:
|
||
image_cv = url_to_cv_image(request.url)
|
||
original_img_size = [image_cv.shape[1], image_cv.shape[0]]
|
||
except HTTPException as e:
|
||
logger.error(f"图像处理失败,URL: {request.url}", exc_info=True)
|
||
return JSONResponse(status_code=e.status_code, content={"error": e.detail})
|
||
|
||
requested_types: Set[str] = {t.strip().lower() for t in request.type.split(',') if t.strip()}
|
||
if not requested_types:
|
||
return JSONResponse(status_code=400, content={"error": "未指定任何检测类型。"})
|
||
logger.info(f"已解析的检测类型: {requested_types}")
|
||
|
||
aggregated_targets: List[Dict] = []
|
||
error_messages: List[str] = []
|
||
|
||
# --- 新的智能分发逻辑 ---
|
||
# 1. 识别并聚合所有PPE相关的任务
|
||
ppe_tasks_requested = requested_types.intersection(PPE_TASK_KEYWORDS)
|
||
if ppe_tasks_requested:
|
||
# 将用户请求的任务关键字 (hardhat, vest) 转换为后端需要的违规代码 (nohelmet, novest)
|
||
violations_to_report = {VIOLATION_MAP[task] for task in ppe_tasks_requested}
|
||
|
||
logger.info(f"分发PPE任务到检测器: {ppe_tasks_requested}, 需要报告的违规: {violations_to_report}")
|
||
try:
|
||
# 将需要报告的违规类型集合传递给检测函数
|
||
results = detect_hardhat_with_sahi(image_cv, requested_violations=violations_to_report, extract=request.extract)
|
||
aggregated_targets.extend(results)
|
||
except Exception as e:
|
||
logger.error(f"处理PPE任务 {ppe_tasks_requested} 时发生严重错误", exc_info=True)
|
||
error_messages.append(f"处理PPE任务时出错: {e}")
|
||
|
||
# 2. 识别并分发光伏板相关任务
|
||
pho_tasks_to_run = requested_types.intersection(PHO_DETECTOR_SUPPORTED_CLASSES)
|
||
if pho_tasks_to_run:
|
||
logger.info(f"分发光伏板相关任务到检测器: {pho_tasks_to_run}")
|
||
try:
|
||
# 确保传递一个列表给detect_objects_by_class
|
||
results = detect_objects_by_class(image_cv, list(pho_tasks_to_run))
|
||
aggregated_targets.extend(results)
|
||
except Exception as e:
|
||
logger.error(f"处理光伏板任务 {pho_tasks_to_run} 时发生严重错误", exc_info=True)
|
||
error_messages.append(f"处理光伏板任务时出错: {e}")
|
||
|
||
# 3. 处理其他完全独立的任务
|
||
remaining_tasks = requested_types - ppe_tasks_requested - pho_tasks_to_run
|
||
if remaining_tasks:
|
||
logger.info(f"正在处理其他独立任务: {remaining_tasks}")
|
||
for task in remaining_tasks:
|
||
try:
|
||
detector_results = None
|
||
if task == "smoking":
|
||
detector_results = detect_smoking_behavior(image_cv, extract=request.extract)
|
||
elif task in ["fire", "smoke"]:
|
||
detector_results = detect_fire_smoke_with_sahi(image_cv, extract=request.extract)
|
||
else:
|
||
logger.warning(f"检测到未知任务类型,将被忽略: {task}")
|
||
error_messages.append(f"未知的检测类型: '{task}'")
|
||
continue
|
||
|
||
if isinstance(detector_results, list):
|
||
aggregated_targets.extend(detector_results)
|
||
except Exception as e:
|
||
logger.error(f"调用检测器处理任务 '{task}' 时失败", exc_info=True)
|
||
error_messages.append(f"处理 '{task}' 时出错: {e}")
|
||
|
||
# --- 最终响应构建 ---
|
||
response_content = {
|
||
"hasTarget": 1 if aggregated_targets else 0,
|
||
"originalImgSize": original_img_size,
|
||
"targets": aggregated_targets,
|
||
}
|
||
if error_messages:
|
||
response_content["processing_errors"] = error_messages
|
||
logger.warning(f"请求处理完成,但包含错误信息: {error_messages}")
|
||
|
||
return JSONResponse(status_code=200, content=response_content)
|
||
|
||
|
||
@app.get("/")
|
||
def read_root():
|
||
return {"message": "高级检测API正在运行。"} |