现有服务备份

This commit is contained in:
2025-07-24 12:45:27 +08:00
commit 7ae047c7c2
23 changed files with 161390 additions and 0 deletions

42
Dockerfile Normal file
View File

@ -0,0 +1,42 @@
# Base image: CUDA 12.6 "devel" flavour on Ubuntu 22.04. The devel variant ships
# build tools in case a Python dependency has to be compiled from source.
FROM nvidia/cuda:12.6.0-devel-ubuntu22.04
# Set the container time zone (optional, but keeps log timestamps in local time).
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# Working directory for the application.
WORKDIR /app
# System libraries needed at runtime (e.g. by OpenCV).
# libgl1-mesa-glx is the usual desktop OpenGL package; on a headless server libgl1 is normally enough.
# Some OpenCV builds additionally need the shared libraries listed below.
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1 \
libglib2.0-0 libsm6 libxext6 libxrender-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Install Python and pip. The apt package index is removed in the same layer so
# it is not baked into the image.
RUN apt-get update && apt-get install -y python3 python3-pip \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
ENV OMP_NUM_THREADS=1
# Copy the dependency list first so the install layers below are cached
# independently of application code changes.
COPY requirements.txt .
# --no-cache-dir keeps the downloaded wheels out of the image layer.
RUN pip install --no-cache-dir torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
# Copy the whole application, including the model weights, into the image.
# NOTE(review): this also copies everything else in the build context (e.g. a
# local logs/ directory) unless a .dockerignore excludes it - confirm.
COPY . .
# The COPY above is what puts the model directory in place.
# To debug model path problems at build time, temporarily enable:
# RUN ls -lR app/models/
# Port the API service listens on.
EXPOSE 8000
# Command executed when the container starts.
# --host 0.0.0.0 makes the service reachable from outside the container.
# --reload is a development option and is normally not used in production.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

20897
logs/app.log Normal file

File diff suppressed because it is too large Load Diff

28105
logs/app.log.1 Normal file

File diff suppressed because it is too large Load Diff

27844
logs/app.log.2 Normal file

File diff suppressed because it is too large Load Diff

27804
logs/app.log.3 Normal file

File diff suppressed because it is too large Load Diff

27846
logs/app.log.4 Normal file

File diff suppressed because it is too large Load Diff

27835
logs/app.log.5 Normal file

File diff suppressed because it is too large Load Diff

182
main.py Normal file
View File

@ -0,0 +1,182 @@
# main.py (完整替换)
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, HttpUrl
import logging
import logging.handlers
import os
from contextlib import asynccontextmanager
import requests
import cv2
import numpy as np
import io
from PIL import Image
from typing import List, Dict, Any, Set
# --- Detector Module Imports ---
# 我们现在只需要导入 detect_hardhat_with_sahi 来处理所有PPE相关的检测
from utils.hardhat_detector import initialize_helmet_model, detect_hardhat_with_sahi
from utils.smoking_detector import initialize_smoking_model, detect_smoking_behavior
from utils.fire_smoke_detector import initialize_fire_smoke_model, detect_fire_smoke_with_sahi
from utils.pho_detector import initialize_all_pho_models, detect_objects_by_class, PHO_DETECTOR_SUPPORTED_CLASSES
# --- 1. 中央日志系统配置 ---
def setup_logging():
    """Configure the root logger and return this module's logger.

    Creates the ``logs`` directory (relative to the current working directory)
    if needed, then installs exactly two handlers on the root logger: a
    rotating file handler writing ``logs/app.log`` (5 MiB per file, 5 backups,
    UTF-8) and a console stream handler. Any handlers already attached to the
    root logger are removed first, so calling this repeatedly does not
    duplicate output.

    Returns:
        The logger named after this module.
    """
    log_dir = "logs"
    # exist_ok=True removes the check-then-create race of the previous
    # os.path.exists()/os.makedirs() pair when several workers start at once.
    os.makedirs(log_dir, exist_ok=True)
    log_file_path = os.path.join(log_dir, "app.log")

    log_formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - [%(name)s] - [%(module)s.%(funcName)s:%(lineno)d] - %(message)s"
    )

    file_handler = logging.handlers.RotatingFileHandler(
        log_file_path, maxBytes=5 * 1024 * 1024, backupCount=5, encoding="utf-8"
    )
    file_handler.setFormatter(log_formatter)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(log_formatter)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    # Drop whatever was installed before (e.g. by logging.basicConfig in an
    # imported module) so each record is emitted once per handler.
    root_logger.handlers.clear()
    root_logger.addHandler(file_handler)
    root_logger.addHandler(stream_handler)
    return logging.getLogger(__name__)
logger = setup_logging()
# --- 2. 应用启动/关闭事件 ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize every detection model when the application starts.

    Each initializer runs in its own try/except so that one model failing to
    load is logged (with traceback) without preventing the others from loading
    or the application from starting. A message is logged after shutdown.
    """
    logger.info("应用启动程序开始...")
    # (initializer, message logged if it raises), executed in this order.
    startup_steps = (
        (initialize_helmet_model, "初始化安全帽PPE模型失败"),
        (initialize_smoking_model, "初始化吸烟smoking模型失败"),
        (initialize_fire_smoke_model, "初始化火焰和烟雾模型失败"),
        (initialize_all_pho_models, "初始化光伏模型失败"),
    )
    for load_model, failure_message in startup_steps:
        try:
            load_model()
        except Exception:
            logger.error(failure_message, exc_info=True)
    yield
    logger.info("应用已关闭。")
app = FastAPI(lifespan=lifespan)
# --- 3. 请求模型和辅助函数 ---
class DetectionRequest(BaseModel):
    """Request body accepted by the ``/detect_image`` endpoint."""

    # Comma-separated detection type names; the endpoint splits on ',' and
    # strips/lower-cases each entry.
    type: str
    # URL of the image to download and analyse.
    url: HttpUrl
    # Forwarded to the detectors as their ``extract`` argument.
    extract: bool = False
def url_to_cv_image(url: str) -> np.ndarray:
    """Download an image and return it as an OpenCV BGR array.

    Args:
        url: Location of the image; converted with ``str()`` before use, so a
            pydantic URL object is accepted as well.

    Returns:
        The decoded image as a BGR ``numpy`` array.

    Raises:
        HTTPException: status 400 when the download fails (network error,
            timeout after 15 s, non-2xx status) or the payload is not a
            recognisable image; status 500 for any other failure.
    """
    # Local import keeps this function self-contained; Pillow is already a
    # dependency of this module.
    from PIL import UnidentifiedImageError

    try:
        response = requests.get(str(url), timeout=15)
        response.raise_for_status()
        image_pil = Image.open(io.BytesIO(response.content)).convert("RGB")
        # PIL yields RGB; the detectors work on OpenCV's BGR channel order.
        return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    except (requests.exceptions.RequestException, UnidentifiedImageError) as e:
        # A URL that does not point at a decodable image is a client error,
        # not a server fault, so it shares the 400 path with download errors.
        raise HTTPException(status_code=400, detail=f"无法下载或处理图片。错误: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"处理图片时发生未知错误: {e}") from e
# --- 4. 路由和核心逻辑 ---
# PPE-related task keywords and their mapping to violation codes.
# PPE_TASK_KEYWORDS: request type names that are routed to the PPE detector.
PPE_TASK_KEYWORDS = {"nohelmet", "novest"}
# VIOLATION_MAP: task keyword -> violation code passed to the detector
# (currently an identity mapping).
VIOLATION_MAP = {
"nohelmet": "nohelmet",
"novest": "novest",
}
@app.post("/detect_image")
async def run_detection(request: DetectionRequest):
    """Main detection endpoint: dispatch the requested tasks to their detectors.

    The image is downloaded once, the comma-separated ``request.type`` is
    parsed into a set of task names, and the tasks are handled in three groups:
    PPE tasks (one aggregated detector call), photovoltaic tasks (one call),
    and the remaining stand-alone tasks (``smoking``, ``fire``, ``smoke``).
    Failures of individual detectors are collected instead of aborting the
    request.

    Returns:
        A ``JSONResponse``. On success (HTTP 200) the body contains
        ``hasTarget`` (1/0), ``originalImgSize`` ([width, height]),
        ``targets`` and, if anything went wrong, ``processing_errors``.
        If the image cannot be obtained the status code of the underlying
        ``HTTPException`` is returned; an empty type list yields HTTP 400.
    """
    logger.info(f"收到检测请求types='{request.type}', url='{request.url}'")
    try:
        image_cv = url_to_cv_image(request.url)
        original_img_size = [image_cv.shape[1], image_cv.shape[0]]
    except HTTPException as e:
        logger.error(f"图像处理失败URL: {request.url}", exc_info=True)
        return JSONResponse(status_code=e.status_code, content={"error": e.detail})

    requested_types: Set[str] = {t.strip().lower() for t in request.type.split(',') if t.strip()}
    if not requested_types:
        return JSONResponse(status_code=400, content={"error": "未指定任何检测类型。"})
    logger.info(f"已解析的检测类型: {requested_types}")

    aggregated_targets: List[Dict] = []
    error_messages: List[str] = []

    # --- Dispatch logic ---
    # 1. Collect all PPE-related tasks and run the PPE detector once for them.
    ppe_tasks_requested = requested_types.intersection(PPE_TASK_KEYWORDS)
    if ppe_tasks_requested:
        # Translate the requested task keywords into the violation codes the
        # detector filters on.
        violations_to_report = {VIOLATION_MAP[task] for task in ppe_tasks_requested}
        logger.info(f"分发PPE任务到检测器: {ppe_tasks_requested}, 需要报告的违规: {violations_to_report}")
        try:
            results = detect_hardhat_with_sahi(image_cv, requested_violations=violations_to_report, extract=request.extract)
            aggregated_targets.extend(results)
        except Exception as e:
            logger.error(f"处理PPE任务 {ppe_tasks_requested} 时发生严重错误", exc_info=True)
            error_messages.append(f"处理PPE任务时出错: {e}")

    # 2. Photovoltaic tasks go to the pho detector in a single call.
    pho_tasks_to_run = requested_types.intersection(PHO_DETECTOR_SUPPORTED_CLASSES)
    if pho_tasks_to_run:
        logger.info(f"分发光伏板相关任务到检测器: {pho_tasks_to_run}")
        try:
            # detect_objects_by_class expects a list, not a set.
            results = detect_objects_by_class(image_cv, list(pho_tasks_to_run))
            aggregated_targets.extend(results)
        except Exception as e:
            logger.error(f"处理光伏板任务 {pho_tasks_to_run} 时发生严重错误", exc_info=True)
            error_messages.append(f"处理光伏板任务时出错: {e}")

    # 3. Everything else is an independent task.
    remaining_tasks = requested_types - ppe_tasks_requested - pho_tasks_to_run
    if remaining_tasks:
        logger.info(f"正在处理其他独立任务: {remaining_tasks}")
        # "fire" and "smoke" share one detector call that returns every class
        # it finds; running it once per keyword would duplicate all targets.
        fire_smoke_dispatched = False
        for task in remaining_tasks:
            try:
                detector_results = None
                if task == "smoking":
                    detector_results = detect_smoking_behavior(image_cv, extract=request.extract)
                elif task in ["fire", "smoke"]:
                    if fire_smoke_dispatched:
                        continue
                    fire_smoke_dispatched = True
                    detector_results = detect_fire_smoke_with_sahi(image_cv, extract=request.extract)
                else:
                    logger.warning(f"检测到未知任务类型,将被忽略: {task}")
                    error_messages.append(f"未知的检测类型: '{task}'")
                    continue

                if isinstance(detector_results, list):
                    aggregated_targets.extend(detector_results)
                elif isinstance(detector_results, dict) and "error" in detector_results:
                    # Some detectors report failure as {"error": ...} instead
                    # of raising; surface it rather than dropping it silently.
                    logger.error(f"调用检测器处理任务 '{task}' 时失败")
                    error_messages.append(f"处理 '{task}' 时出错: {detector_results['error']}")
            except Exception as e:
                logger.error(f"调用检测器处理任务 '{task}' 时失败", exc_info=True)
                error_messages.append(f"处理 '{task}' 时出错: {e}")

    # --- Build the final response ---
    response_content = {
        "hasTarget": 1 if aggregated_targets else 0,
        "originalImgSize": original_img_size,
        "targets": aggregated_targets,
    }
    if error_messages:
        response_content["processing_errors"] = error_messages
        logger.warning(f"请求处理完成,但包含错误信息: {error_messages}")
    return JSONResponse(status_code=200, content=response_content)
@app.get("/")
def read_root():
    """Root endpoint: return a static message confirming the API is running."""
    status_payload = {"message": "高级检测API正在运行。"}
    return status_payload

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
models/helmet_model/best.pt Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

11
requirements.txt Normal file
View File

@ -0,0 +1,11 @@
fastapi
uvicorn[standard]
pydantic
python-multipart
Pillow
opencv-python-headless
numpy
supervision
sahi
ultralytics
requests

View File

@ -0,0 +1,93 @@
import cv2
import numpy as np
# import supervision as sv # No longer directly needed
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output
import logging
logger = logging.getLogger(__name__)
# Class names indexed by the model's category ID (0 -> "fire", 1 -> "smoggy").
# NOTE(review): the order presumably has to match the training labels - confirm.
CLASS_NAMES_FIRE_SMOKE = ["fire", "smoggy"]
# Weights file loaded by initialize_fire_smoke_model().
MODEL_WEIGHTS_PATH_FIRE_SMOKE = "models/fire_smoke_model/best.pt"
# SAHI slicing settings used for the (only) detection pass.
SAHI_PARAMS_LARGE = {"slice_height": 2000, "slice_width": 2000, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# Smaller slices; not referenced by the functions in this module.
SAHI_PARAMS_SMALL = {"slice_height": 256, "slice_width": 256, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# Lazily created module-level model instance (None until initialized).
detection_model_fire_smoke = None
def initialize_fire_smoke_model():
    """Load the fire/smoke model once and cache it in the module global.

    Returns:
        The cached SAHI detection model.

    Raises:
        RuntimeError: if the model cannot be loaded.
    """
    global detection_model_fire_smoke
    # Already loaded: hand back the cached instance.
    if detection_model_fire_smoke is not None:
        return detection_model_fire_smoke

    logger.info(f"Loading fire_smoke detection model from {MODEL_WEIGHTS_PATH_FIRE_SMOKE}...")
    try:
        loaded_model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_FIRE_SMOKE,
            confidence_threshold=0.8,
        )
        detection_model_fire_smoke = loaded_model
        logger.info("Fire_smoke detection model loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading fire_smoke model: {e}", exc_info=True)
        raise RuntimeError(f"Could not load fire_smoke model: {e}")
    return detection_model_fire_smoke
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    """Run SAHI sliced prediction and keep predictions with a known class ID.

    Args:
        image_cv: Image passed to ``get_sliced_prediction``.
        model: SAHI detection model.
        sahi_params: Must provide ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        model_class_names: Class names indexed by category ID; used only to
            bound-check the IDs.

    Returns:
        The list of predictions whose category ID is a valid index into
        ``model_class_names``; an empty list if the prediction call fails.
    """
    try:
        slicing_kwargs = {
            key: sahi_params[key]
            for key in ("slice_height", "slice_width", "overlap_height_ratio", "overlap_width_ratio")
        }
        sahi_result = get_sliced_prediction(
            image=image_cv, detection_model=model, verbose=0, **slicing_kwargs
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    class_count = len(model_class_names)
    kept_predictions = []
    for prediction in sahi_result.object_prediction_list:
        category_id = int(prediction.category.id)
        if category_id < 0 or category_id >= class_count:
            logger.warning(f"SAHI Raw: Detected class ID {category_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
            continue
        kept_predictions.append(prediction)
    return kept_predictions
def detect_fire_smoke_with_sahi(image_cv: np.ndarray, extract: bool = False):
    """Detect fire/smoke in an image with a single large-slice SAHI pass.

    Args:
        image_cv: Image to analyse.
        extract: Only logged; it does not influence the result here.

    Returns:
        A list of target dicts with keys ``type``, ``size`` ([width, height]),
        ``leftTopPoint`` ([x, y]) and ``score`` (rounded to 4 decimals), or an
        ``{"error": ...}`` dict if the model is still unavailable after an
        initialization attempt.
    """
    global detection_model_fire_smoke
    if detection_model_fire_smoke is None:
        logger.warning("Fire_smoke model was not loaded. Attempting to initialize now.")
        initialize_fire_smoke_model()
        if detection_model_fire_smoke is None:
            logger.error("Fire_smoke model could not be initialized for detection.")
            return {"error": "Fire_smoke model is not available."}

    logger.info(f"Executing fire_smoke detection, extract flag is: {extract}")
    logger.info("Attempting fire_smoke detection with SAHI (Large Slices)...")
    predictions = _run_sahi_and_process_results(
        image_cv, detection_model_fire_smoke, SAHI_PARAMS_LARGE, CLASS_NAMES_FIRE_SMOKE
    )
    if not predictions:
        logger.info("No SAHI fire_smoke detections (large slices). Skipping small slices attempt as requested.")
        predictions = []

    def to_target(prediction):
        # Convert one SAHI prediction into the API's target dictionary.
        box = prediction.bbox
        return {
            "type": CLASS_NAMES_FIRE_SMOKE[int(prediction.category.id)],
            "size": [int(box.maxx - box.minx), int(box.maxy - box.miny)],
            "leftTopPoint": [int(box.minx), int(box.miny)],
            "score": round(float(prediction.score.value), 4)
        }

    targets_output_list = [to_target(prediction) for prediction in predictions]
    logger.info(f"Prepared {len(targets_output_list)} fire/smoke targets for API response.")
    return targets_output_list

View File

@ -0,0 +1,111 @@
import cv2
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output if base64 image isn't returned
import logging
logger = logging.getLogger(__name__)
# Class names indexed by the model's category ID.
# NOTE(review): "vast"/"novast" look like a spelling of "vest"/"novest"; they
# are only used for ID lookup here - confirm against the model's labels.
CLASS_NAMES_HELMET_MODEL_OUTPUT = ["helmet", "nohelmet", "vast", "novast"]
# Only detections of this class are included in the returned targets.
TARGET_CLASS_FOR_REPORTING = "nohelmet"
# Weights file loaded by initialize_helmet_model().
MODEL_WEIGHTS_PATH_HELMET = "models/helmet_model/best.pt"
# First detection pass uses the large slices; the small slices are the fallback
# when the first pass finds nothing.
SAHI_PARAMS_LARGE = {"slice_height": 1200, "slice_width": 1200, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_SMALL = {"slice_height": 700, "slice_width": 700, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# Lazily created module-level model instance (None until initialized).
detection_model_helmet = None
def initialize_helmet_model():
    """Load the helmet model once and cache it in the module global.

    Returns:
        The cached SAHI detection model.

    Raises:
        RuntimeError: if the model cannot be loaded.
    """
    global detection_model_helmet
    # Already loaded: hand back the cached instance.
    if detection_model_helmet is not None:
        return detection_model_helmet

    logger.info(f"Loading helmet detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
    try:
        loaded_model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_HELMET,
            confidence_threshold=0.8,
        )
        detection_model_helmet = loaded_model
        logger.info("Helmet detection model loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading helmet model: {e}", exc_info=True)
        raise RuntimeError(f"Could not load helmet model: {e}")
    return detection_model_helmet
# encode_image_to_base64 is no longer strictly needed if not returning annotated image
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    """Run SAHI sliced prediction and keep predictions with a known class ID.

    Args:
        image_cv: Image passed to ``get_sliced_prediction``.
        model: SAHI detection model.
        sahi_params: Must provide ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        model_class_names: Class names indexed by category ID; used only to
            bound-check the IDs.

    Returns:
        The list of predictions whose category ID is a valid index into
        ``model_class_names``; an empty list if the prediction call fails.
    """
    try:
        slicing_kwargs = {
            key: sahi_params[key]
            for key in ("slice_height", "slice_width", "overlap_height_ratio", "overlap_width_ratio")
        }
        sahi_result = get_sliced_prediction(
            image=image_cv, detection_model=model, verbose=0, **slicing_kwargs
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []  # A failed SAHI call is reported as "no predictions".

    class_count = len(model_class_names)
    kept_predictions = []
    for prediction in sahi_result.object_prediction_list:
        category_id = int(prediction.category.id)
        if category_id < 0 or category_id >= class_count:
            logger.warning(f"SAHI Raw: Detected class ID {category_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
            continue
        kept_predictions.append(prediction)
    return kept_predictions
def detect_hardhat_with_sahi(image_cv: np.ndarray, extract: bool = False):
    """Detect missing helmets, retrying with smaller slices if nothing is found.

    A large-slice SAHI pass runs first; only when it yields no predictions is
    a second pass with the small-slice parameters attempted. Of the resulting
    predictions, only those of class ``TARGET_CLASS_FOR_REPORTING`` are
    returned.

    Args:
        image_cv: Image to analyse.
        extract: Only logged; it does not influence the result here.

    Returns:
        A list of target dicts with keys ``type``, ``size`` ([width, height]),
        ``leftTopPoint`` ([x, y]) and ``score`` (rounded to 4 decimals), or an
        ``{"error": ...}`` dict if the model is still unavailable after an
        initialization attempt.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("Helmet model was not loaded. Attempting to initialize now.")
        initialize_helmet_model()
        if detection_model_helmet is None:
            logger.error("Helmet model could not be initialized for detection.")
            return {"error": "Helmet model is not available."}

    logger.info(f"Executing helmet task (reporting only '{TARGET_CLASS_FOR_REPORTING}'), extract flag is: {extract}")
    logger.info("Attempting detection with SAHI (Large Slices)...")
    predictions = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS_LARGE, CLASS_NAMES_HELMET_MODEL_OUTPUT
    )
    if predictions:
        logger.info(f"Initial SAHI detections (large slices): {len(predictions)}")
    else:
        logger.info("No SAHI detections (large slices). Retrying with small slices...")
        predictions = _run_sahi_and_process_results(
            image_cv, detection_model_helmet, SAHI_PARAMS_SMALL, CLASS_NAMES_HELMET_MODEL_OUTPUT
        )
        if predictions:
            logger.info(f"Initial SAHI detections (small slices): {len(predictions)}")
        else:
            logger.info("No initial SAHI detections found even with small slices.")
            predictions = []

    def to_target(prediction, label):
        # Convert one SAHI prediction into the API's target dictionary.
        box = prediction.bbox
        return {
            "type": label,
            "size": [int(box.maxx - box.minx), int(box.maxy - box.miny)],
            "leftTopPoint": [int(box.minx), int(box.miny)],
            "score": round(float(prediction.score.value), 4)
        }

    # Keep only the class that is meant to be reported.
    targets_output_list = []
    for prediction in predictions:
        label = CLASS_NAMES_HELMET_MODEL_OUTPUT[int(prediction.category.id)]
        if label != TARGET_CLASS_FOR_REPORTING:
            continue
        targets_output_list.append(to_target(prediction, label))

    logger.info(f"Prepared {len(targets_output_list)} '{TARGET_CLASS_FOR_REPORTING}' targets for API response.")
    return targets_output_list

199
utils/hardhat_detector.py Normal file
View File

@ -0,0 +1,199 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List # 导入Set和List以支持类型提示
# 使用在main.py中配置好的根logger
logger = logging.getLogger(__name__)
# --- 1. Model and class configuration (adjust this section to the actual model) ---
# Class names of the "state detection" model. The order must match the order
# used when the model was trained, because IDs are mapped to names by index.
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]
# Violation type strings expected by the backend.
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"
# Path to the weights file of the new model.
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"
# --- 2. Existing call interface, kept and adapted ---
# SAHI slicing parameters.
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}
# Module-level model instance.
detection_model_helmet = None # The name 'detection_model_helmet' is kept for consistency with the old version.
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def initialize_helmet_model():
    """Load the PPE state model once and cache it in the module global.

    Keeps the name and calling convention expected by main.py. After loading,
    the model's own class names are compared, in class-ID order, with
    ``CLASS_NAMES_PPE_STATE_MODEL``; a mismatch only produces a warning.

    Returns:
        The cached SAHI detection model.

    Raises:
        FileNotFoundError: if the weights file does not exist.
        RuntimeError: if loading (or the class-name check) fails.
    """
    global detection_model_helmet
    if detection_model_helmet is not None:
        return detection_model_helmet

    if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
        error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)

    logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
    try:
        detection_model_helmet = AutoDetectionModel.from_pretrained(
            model_type="yolov8",
            model_path=MODEL_WEIGHTS_PATH_HELMET,
            confidence_threshold=0.83,
            device=DEVICE,
        )
        # Optional validation. Detections are mapped with
        # CLASS_NAMES_PPE_STATE_MODEL[class_id], so the ORDER of the names
        # matters; comparing sorted lists (as before) could not notice a
        # permutation that would silently mislabel every detection.
        id_to_name = detection_model_helmet.model.names
        loaded_class_names = [id_to_name[class_id] for class_id in sorted(id_to_name)]
        if loaded_class_names != CLASS_NAMES_PPE_STATE_MODEL:
            logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
                           f"but model has {loaded_class_names}. Please check consistency.")
        logger.info("hardhat state detection model loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
        raise RuntimeError(f"Could not load hardhat state model: {e}") from e
    return detection_model_helmet
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
    """Run SAHI sliced inference and return every prediction with a valid class ID.

    Slice results are merged with GREEDYNMM using the IOS metric at a 0.5
    threshold. A failed prediction call is logged and yields an empty list.
    """
    try:
        slicing_kwargs = {
            key: sahi_params[key]
            for key in ("slice_height", "slice_width", "overlap_height_ratio", "overlap_width_ratio")
        }
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.5,
            **slicing_kwargs
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    if sahi_result:
        raw_predictions = sahi_result.object_prediction_list
    else:
        raw_predictions = []

    class_count = len(model_class_names)
    kept_predictions = []
    for prediction in raw_predictions:
        category_id = int(prediction.category.id)
        if category_id < 0 or category_id >= class_count:
            logger.warning(f"SAHI Raw: Detected class ID {category_id} is out of range. Ignoring prediction.")
            continue
        kept_predictions.append(prediction)
    return kept_predictions
def map_model_output_to_backend_violations(class_name: str) -> list[str]:
    """Translate a model state class name into backend violation codes.

    ``"nohelmet"`` and ``"novest"`` map to their single code, ``"noequipment"``
    maps to both (helmet first), and any other name maps to an empty list.
    """
    violations: list[str] = []
    if class_name in ("nohelmet", "noequipment"):
        violations.append(BACKEND_VIOLATION_CODE_NO_HELMET)
    if class_name in ("novest", "noequipment"):
        violations.append(BACKEND_VIOLATION_CODE_NO_VEST)
    return violations
def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
    """Detect PPE violations and return only the requested violation types.

    Runs one SAHI pass, discards boxes that fail a set of geometric
    plausibility checks, maps each remaining model class to backend violation
    codes and keeps the codes present in ``requested_violations``.

    Args:
        image_cv: Image with a three-element shape (height, width, channels).
        requested_violations: Violation codes that may appear in the output.
        extract: Accepted for interface compatibility; not used here.

    Returns:
        A list of target dicts with keys ``type``, ``size`` ([width, height]),
        ``leftTopPoint`` ([x, y]) and ``score`` (rounded to 4 decimals).

    Raises:
        RuntimeError: if the model is not loaded and cannot be initialized.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("hardhat state model was not loaded. Initializing on first call.")
        try:
            initialize_helmet_model()
        except Exception as e:
            logger.error(f"Failed to initialize model on demand: {e}")
            raise RuntimeError(f"hardhat model could not be initialized: {e}")

    logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")
    img_height, img_width, _ = image_cv.shape

    candidates = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
    )
    if not candidates:
        logger.info("No raw objects detected by SAHI.")
        return []
    logger.info(f"SAHI raw detection count: {len(candidates)}. Applying post-processing filters...")

    # --- Post-processing filter thresholds ---
    min_person_height_px = 40
    max_person_width_ratio = 0.21
    max_person_height_ratio = 0.33
    min_aspect_ratio = 1.2
    max_aspect_ratio = 5.0
    feet_position_threshold_ratio = 0.05
    image_area = img_width * img_height

    def passes_filters(prediction) -> bool:
        # Apply the plausibility checks in a fixed order; the first failing
        # check rejects the box.
        left, top, right, bottom = map(int, prediction.bbox.to_xyxy())
        confidence = prediction.score.value
        box_width = right - left
        box_height = bottom - top
        if box_height < min_person_height_px:
            return False
        width_share = box_width / img_width
        height_share = box_height / img_height
        if width_share > max_person_width_ratio or height_share > max_person_height_ratio:
            return False
        if box_width == 0:
            return False
        height_to_width = box_height / box_width
        if height_to_width < min_aspect_ratio or height_to_width > max_aspect_ratio:
            return False
        if bottom < (img_height * feet_position_threshold_ratio):
            return False
        if box_width * box_height > (image_area * 0.1) and confidence < 0.5:
            return False
        return True

    filtered_predictions = [prediction for prediction in candidates if passes_filters(prediction)]
    logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")

    # --- Map model classes to violations and keep only the requested ones ---
    targets_output_list = []
    for prediction in filtered_predictions:
        class_name = CLASS_NAMES_PPE_STATE_MODEL[int(prediction.category.id)]
        score = round(float(prediction.score.value), 4)
        minx, miny, maxx, maxy = map(int, prediction.bbox.to_xyxy())
        logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")

        for code in map_model_output_to_backend_violations(class_name):
            if code not in requested_violations:
                continue
            logger.info(f" └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
            targets_output_list.append({
                "type": code,
                "size": [int(maxx - minx), int(maxy - miny)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": score
            })

    logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
    return targets_output_list

163
utils/pho_detector.py Normal file
View File

@ -0,0 +1,163 @@
# utils/pho_detector.py
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
from typing import List, Dict, Any
# --- Basic Configuration ---
# NOTE(review): basicConfig() runs at import time; it only has an effect if the
# root logger has no handlers yet - confirm it is still wanted in a library module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Module-Specific Configuration ---
# All classes this module can detect.
ALL_SUPPORTED_CLASSES = ['pho', 'shelves', 'pile', 'hole']
# Configuration for the first model (detects 'pho' and 'shelves').
MODEL_1_CONFIG = {
"name": "pho_model",
# IMPORTANT: Update this path to your actual model location.
"weights_path": "models/photovoltaic_model/pho_model/best.pt",
# Class order must match the model's training IDs (0: 'pho', 1: 'shelves').
"class_names": ['pho', 'shelves'],
"confidence_threshold": 0.7
}
# Configuration for the second model (detects 'pile' and 'hole').
MODEL_2_CONFIG = {
"name": "pile_model",
# IMPORTANT: Update this path to your actual model location.
"weights_path": "models/photovoltaic_model/pile_model/best.pt",
# Class order must match the model's training IDs (0: 'pile', 1: 'hole').
"class_names": ['pile', 'hole'],
"confidence_threshold": 0.7
}
# Unified SAHI slicing parameters for all tasks in this module.
SAHI_PARAMS = {
"slice_height": 800,
"slice_width": 800,
"overlap_height_ratio": 0.2,
"overlap_width_ratio": 0.2
}
# Global dictionary to cache loaded model instances, keyed by the config's "name".
detection_models: Dict[str, AutoDetectionModel] = {}
# --- Model Management Functions ---
def initialize_model(model_config: Dict[str, Any]) -> AutoDetectionModel:
    """Return the model described by ``model_config``, loading it on first use.

    Loaded models are cached in ``detection_models`` under the config's
    ``"name"``.

    Raises:
        RuntimeError: if the model cannot be loaded.
    """
    cache_key = model_config["name"]
    if cache_key in detection_models:
        return detection_models[cache_key]

    logger.info(f"Loading detection model '{cache_key}' from {model_config['weights_path']}...")
    try:
        loaded_model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=model_config["weights_path"],
            confidence_threshold=model_config["confidence_threshold"],
        )
        detection_models[cache_key] = loaded_model
        logger.info(f"Model '{cache_key}' loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading model '{cache_key}': {e}", exc_info=True)
        raise RuntimeError(f"Could not load model '{cache_key}': {e}")
    return detection_models[cache_key]
def initialize_all_pho_models():
    """Load both photovoltaic models up front; intended for application startup."""
    logger.info("Pre-initializing all photovoltaic-related models...")
    for model_config in (MODEL_1_CONFIG, MODEL_2_CONFIG):
        initialize_model(model_config)
    logger.info("All photovoltaic models have been initialized.")
# --- Inference and Core Logic ---
def _run_inference_for_model(image_cv: np.ndarray, model_config: Dict[str, Any]) -> List[Dict]:
    """Run SAHI prediction with one configured model and format the results.

    Returns:
        A list of target dicts (``type``, ``size``, ``leftTopPoint``,
        ``score``). An empty list is returned if the model cannot be
        initialized or the prediction call fails.
    """
    model_name = model_config["name"]
    try:
        model = initialize_model(model_config)
    except RuntimeError as e:
        logger.error(f"Cannot run inference; model '{model_name}' failed to initialize: {e}")
        return []

    logger.info(f"Running SAHI prediction with model: {model_name}...")
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            **SAHI_PARAMS,
            verbose=0
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed for model '{model_name}': {e}", exc_info=True)
        return []

    # Each model has its own ID space; its class list maps local IDs to names.
    class_names = model_config["class_names"]
    targets_output_list = []
    for prediction in sahi_result.object_prediction_list:
        category_id = int(prediction.category.id)
        if category_id < 0 or category_id >= len(class_names):
            logger.warning(f"Model '{model_name}' detected an out-of-range class ID: {category_id}. Ignoring.")
            continue
        box = prediction.bbox
        targets_output_list.append({
            "type": class_names[category_id],
            "size": [int(box.maxx - box.minx), int(box.maxy - box.miny)],
            "leftTopPoint": [int(box.minx), int(box.miny)],
            "score": round(float(prediction.score.value), 4)
        })

    logger.info(f"Model '{model_name}' found {len(targets_output_list)} raw objects.")
    return targets_output_list
def detect_objects_by_class(image_cv: np.ndarray, target_classes: List[str]) -> List[Dict]:
    """Detect the requested classes, running only the model(s) that cover them.

    Args:
        image_cv: The input OpenCV image.
        target_classes: Class names to detect (e.g. ['pho', 'hole']).

    Returns:
        Detection dictionaries restricted to the requested classes; results of
        the first model precede those of the second.
    """
    final_results = []
    for model_number, model_config in enumerate((MODEL_1_CONFIG, MODEL_2_CONFIG), start=1):
        # Requested classes this particular model is responsible for.
        wanted = [cls for cls in target_classes if cls in model_config["class_names"]]
        if not wanted:
            continue
        logger.info(f"Model {model_number} ('{model_config['name']}') triggered for targets: {wanted}")
        raw_results = _run_inference_for_model(image_cv, model_config)
        # The model reports all of its classes; keep only what was asked for.
        final_results.extend(res for res in raw_results if res["type"] in wanted)

    logger.info(f"Photovoltaic detection complete. Returning {len(final_results)} filtered targets.")
    return final_results
# --- Export for main.py ---
# This list is imported by main.py to know which tasks to delegate to this module.
PHO_DETECTOR_SUPPORTED_CLASSES = ALL_SUPPORTED_CLASSES

258
utils/smoking_detector.py Normal file
View File

@ -0,0 +1,258 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import torch # 用于设备检测和YOLO模型加载
from ultralytics import YOLO # 用于阶段二的ROI检测
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) # Added so that this logger's output is visible.
# --- Updated configuration ---
# Class names indexed by the model's category ID; the two ID constants below
# rely on this order.
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID = 0
FACE_CLASS_ID = 1
# Single weights file used for both the SAHI-wrapped and the direct YOLO model.
MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"
# SAHI slicing parameters.
SAHI_PARAMS_FACE_LARGE = {"slice_height": 1280, "slice_width": 1280, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_FACE_SMALL = {"slice_height": 640, "slice_width": 640, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# YOLO confidence threshold for cigarette detection inside a ROI.
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
# NOTE(review): presumably the padding (in pixels) added around a face ROI - confirm at the use site.
FACE_ROI_PADDING = 30
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level model instances (None until initialized).
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None
def initialize_smoking_model():
    """Load the behavior (face & cigarette) model in both of its forms.

    The same weights file is loaded twice: wrapped by SAHI for sliced
    whole-image detection, and directly through Ultralytics ``YOLO`` for ROI
    detection. Both instances are cached in module globals. After loading, the
    model's class names are compared, in class-ID order, with
    ``CLASS_NAMES_BEHAVIOR``; a mismatch only produces a warning.

    Returns:
        A ``(sahi_model, yolo_model)`` tuple.

    Raises:
        RuntimeError: if loading fails; both globals are reset to ``None``
            first so that a later call retries from scratch.
    """
    global detection_model_sahi_behavior, yolo_model_direct_behavior
    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
        try:
            # 1. SAHI-wrapped model for sliced detection on the whole image.
            detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
                model_type="ultralytics",
                model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
                confidence_threshold=0.7,
                device=DEVICE,
            )
            # 2. Model loaded directly with Ultralytics for fast ROI detection.
            yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
            yolo_model_direct_behavior.to(DEVICE)

            if hasattr(yolo_model_direct_behavior, 'names') and yolo_model_direct_behavior.names:
                id_to_name = yolo_model_direct_behavior.names
                model_names = [id_to_name[class_id] for class_id in sorted(id_to_name)]
                # CIGARETTE_CLASS_ID / FACE_CLASS_ID are fixed indices, so the
                # ORDER must match. The previous check compared sorted lists
                # and therefore reported "confirmed" for a swapped order.
                if model_names != CLASS_NAMES_BEHAVIOR:
                    logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
                                   f"Got from model: {model_names}. Please verify consistency.")
                else:
                    logger.info(f"Model class names confirmed: {model_names}")
            else:
                logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")
            logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading behavior model: {e}", exc_info=True)
            # Leave no half-initialized state behind.
            detection_model_sahi_behavior = None
            yolo_model_direct_behavior = None
            raise RuntimeError(f"Could not load behavior model: {e}") from e
    return detection_model_sahi_behavior, yolo_model_direct_behavior
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int, target_class_name: str, model_actual_class_names: list):
    """Run one SAHI sliced prediction and keep only predictions of the target class.

    Args:
        image_cv: Image to run the sliced prediction on.
        model: SAHI detection model passed to ``get_sliced_prediction``.
        sahi_params: Dict providing ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        target_class_id: Class ID a prediction must have to be kept.
        target_class_name: Human-readable class name, used for logging only.
        model_actual_class_names: Class names indexed by ID; only its length is
            used, to discard predictions whose class ID is out of range.

    Returns:
        list: SAHI object predictions of the target class. Empty when nothing
        matched or when the sliced prediction raised (the error is logged).
    """
    try:
        sliced_output = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.7,
            postprocess_class_agnostic=False,
        )
    except Exception as exc:
        # Best effort: a failed pass is reported as "no detections" to the caller.
        logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {exc}", exc_info=True)
        return []

    matches = []
    # A result without 'object_prediction_list' is treated as having no predictions.
    for candidate in getattr(sliced_output, 'object_prediction_list', ()):
        class_id = int(candidate.category.id)
        if class_id < 0 or class_id >= len(model_actual_class_names):
            logger.warning(f"SAHI Raw: Detected class ID {class_id} for '{target_class_name}' "
                           f"is out of range for actual model class names ({len(model_actual_class_names)} classes). Ignoring pred.")
        elif class_id == target_class_id:
            matches.append(candidate)

    logger.info(f"SAHI found {len(matches)} instances of '{target_class_name}' with params {sahi_params}.")
    return matches
# Stage-1 face qualification thresholds (tunable).
_SMOKING_FACE_MIN_SCORE = 0.65
_SMOKING_FACE_MIN_SIDE_PX = 32
_SMOKING_FACE_ASPECT_RANGE = (0.6, 1.6)  # exclusive bounds on width / height

# Stage-2 cigarette qualification thresholds (tunable).
_SMOKING_CIG_MAX_LONG_SIDE_PX = 100
_SMOKING_CIG_MAX_SHORT_SIDE_PX = 40
_SMOKING_CIG_ASPECT_RANGE = (2.5, 25.0)  # exclusive bounds on long side / short side


def _ensure_behavior_models_ready():
    """Make sure both behavior models are loaded; return an error dict on failure, else None."""
    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
        try:
            initialize_smoking_model()
        except RuntimeError:
            logger.error("Behavior model could not be initialized for detection.")
            return {"error": "Behavior model is not available."}
        if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
            logger.error("Behavior model components are still None after initialization attempt.")
            return {"error": "Behavior model failed to initialize."}
    return None


def _behavior_class_names_by_id(names_by_id: dict) -> list:
    """Turn the model's {class_id: name} mapping into a list indexable by class ID."""
    names = [""] * (max(names_by_id) + 1)
    for class_id, class_name in names_by_id.items():
        names[class_id] = class_name
    return names


def _detect_face_candidates(image_cv: np.ndarray, class_names: list) -> list:
    """Stage 1: detect faces with SAHI, large slices first and small slices as fallback."""
    logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
    faces = _run_sahi_and_process_results(
        image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_LARGE,
        FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], class_names
    )
    if faces:
        return faces

    logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
    faces = _run_sahi_and_process_results(
        image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_SMALL,
        FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], class_names
    )
    if faces:
        return faces

    logger.info("No SAHI face detections found even with small slices.")
    return []


def _is_qualified_face(face_pred) -> bool:
    """Check a SAHI face prediction against the score, minimum-size and aspect-ratio limits."""
    # 1. Confidence check.
    if face_pred.score.value < _SMOKING_FACE_MIN_SCORE:
        return False
    # 2. Size check: both sides must reach the minimum, which also keeps the
    #    division below safe.
    x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
    width, height = x2 - x1, y2 - y1
    if width < _SMOKING_FACE_MIN_SIDE_PX or height < _SMOKING_FACE_MIN_SIDE_PX:
        return False
    # 3. Aspect-ratio check (width / height).
    ratio_low, ratio_high = _SMOKING_FACE_ASPECT_RANGE
    return ratio_low < width / height < ratio_high


def _is_qualified_cigarette(width: int, height: int) -> bool:
    """Check a cigarette box against the absolute-size and aspect-ratio limits, logging rejections."""
    long_side = max(width, height)
    short_side = min(width, height)

    # 1. Absolute size check.
    if long_side > _SMOKING_CIG_MAX_LONG_SIDE_PX or short_side > _SMOKING_CIG_MAX_SHORT_SIDE_PX:
        logger.info(f" REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
        return False

    # 2. Aspect-ratio check (filters pole-like false positives): a cigarette is
    #    expected to be clearly longer than it is wide.
    if short_side > 0:
        aspect_ratio = long_side / short_side
        ratio_low, ratio_high = _SMOKING_CIG_ASPECT_RANGE
        if not (ratio_low < aspect_ratio < ratio_high):
            logger.info(f" REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
            return False
        return True

    # Degenerate boxes: a zero-width/height line is logged, an empty box is
    # dropped silently.
    if long_side > 0:
        logger.info(" REJECTED Cigarette: BBox is a zero-width/height line.")
    return False


def _detect_cigarettes_near_face(image_cv: np.ndarray, face_pred) -> list:
    """Stage 2: run direct YOLO on one padded face ROI and return its validated cigarettes.

    Cigarettes are reported only when the ROI prediction also contains a face
    box (cross-validation) and at least one cigarette passes the size checks.
    Returned coordinates are in full-image pixels.
    """
    x1_face, y1_face, x2_face, y2_face = map(int, face_pred.bbox.to_xyxy())

    # Pad the face box and clamp it to the image bounds before cropping.
    roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
    roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
    roi_x2 = min(image_cv.shape[1], x2_face + FACE_ROI_PADDING)
    roi_y2 = min(image_cv.shape[0], y2_face + FACE_ROI_PADDING)
    face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]
    if face_roi_crop.size == 0:
        logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
        return []

    roi_results = yolo_model_direct_behavior.predict(
        source=face_roi_crop, verbose=False, device=DEVICE, conf=CIGARETTE_IN_ROI_CONF_THRESHOLD
    )

    is_face_confirmed_in_roi = False
    cigarettes = []
    if roi_results:
        for box in roi_results[0].boxes:
            class_id = int(box.cls)
            if class_id == FACE_CLASS_ID:
                is_face_confirmed_in_roi = True
            elif class_id == CIGARETTE_CLASS_ID:
                xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())
                # Translate ROI-local coordinates back into full-image coordinates.
                cig_minx_global = xc1_roi + roi_x1
                cig_miny_global = yc1_roi + roi_y1
                cig_width_global = xc2_roi - xc1_roi
                cig_height_global = yc2_roi - yc1_roi
                if not _is_qualified_cigarette(cig_width_global, cig_height_global):
                    continue
                # Only cigarettes that passed every check count as valid.
                cigarettes.append({
                    "type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
                    "size": [int(cig_width_global), int(cig_height_global)],
                    "leftTopPoint": [int(cig_minx_global), int(cig_miny_global)],
                    "score": round(float(box.conf), 4)
                })

    # Final verdict: the ROI must contain both a face and a qualified cigarette.
    found_cigarette_in_roi = bool(cigarettes)
    if is_face_confirmed_in_roi and found_cigarette_in_roi:
        logger.info(f" SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
        return cigarettes

    logger.info(f" INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")
    return []


def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
    """Detect smoking behavior in an image with a two-stage face -> cigarette pipeline.

    Stage 1 finds faces over the whole image with SAHI sliced prediction and
    filters them by confidence, size and aspect ratio. Stage 2 crops a padded
    ROI around each qualified face, runs the direct YOLO model on it, and
    reports cigarettes only when the ROI also contains a face and the cigarette
    box passes the size and aspect-ratio checks.

    Args:
        image_cv: Image as a NumPy array indexed ``[row, column, ...]``.
        extract: Currently unused; kept for interface compatibility.

    Returns:
        list[dict] | dict: On success, a (possibly empty) list of cigarette
        detections with keys ``"type"``, ``"size"`` (``[width, height]``),
        ``"leftTopPoint"`` (``[x, y]`` in full-image pixels) and ``"score"``.
        When the models or their class names are unavailable, a dict of the
        form ``{"error": <message>}`` is returned instead.
    """
    error_payload = _ensure_behavior_models_ready()
    if error_payload is not None:
        return error_payload

    model_actual_class_names_dict = yolo_model_direct_behavior.names
    if not model_actual_class_names_dict:
        logger.error("Cannot determine actual class names from the loaded YOLO model.")
        return {"error": "Model class names missing."}
    model_actual_class_names_list = _behavior_class_names_by_id(model_actual_class_names_dict)

    logger.info("Executing two-stage smoking behavior detection...")

    # --- Stage 1: detect faces with SAHI, then vet every candidate ---
    detected_faces = _detect_face_candidates(image_cv, model_actual_class_names_list)
    if detected_faces:
        qualified_faces = [face for face in detected_faces if _is_qualified_face(face)]
        logger.info(f"Face Qualification: Initial candidates: {len(detected_faces)}, Qualified after filtering: {len(qualified_faces)}")
        detected_faces = qualified_faces

    targets_output_list = []
    if not detected_faces:
        logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
        return targets_output_list

    logger.info(f"Found {len(detected_faces)} qualified faces in Stage 1. Proceeding to Stage 2...")

    # --- Stage 2: look for cigarettes inside each face ROI with direct YOLO ---
    for face_pred in detected_faces:
        targets_output_list.extend(_detect_cigarettes_near_face(image_cv, face_pred))

    logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
    return targets_output_list