# File: AI-TianDong/utils/hardhat_detector.py
# Snapshot: 2025-07-24 12:45:27 +08:00 (199 lines, 8.1 KiB, Python)
#
# NOTE: The repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters
# (most likely the full-width punctuation in the original Chinese comments).
# If that is intentional, the warning can be safely ignored.

import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List # 导入Set和List以支持类型提示
# Module-level logger; per the original note, logging itself is configured in main.py.
logger = logging.getLogger(__name__)
# --- 1. Model and class configuration (adjust this section to match the actual model) ---
# Class names of the new "state detection" model. The order must match the order used
# at training time, because predictions are translated to names by indexing this list
# with the predicted class id.
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]
# Violation type strings the backend expects to receive.
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"
# Path to the new model's weights file (relative, so it is resolved against the
# process working directory).
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"
# --- 2. Keep and adapt the existing call interface ---
# SAHI slicing parameters (slice size in pixels and overlap ratios between slices).
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}
# Global model instance; stays None until initialize_helmet_model() has loaded the model.
detection_model_helmet = None # the name 'detection_model_helmet' is kept for consistency with the previous version
# Device string handed to the model loader: CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def initialize_helmet_model():
    """Load the PPE state detection model once and cache it at module level.

    This is the initialization entry point that follows the calling convention
    expected by main.py. The loaded model is stored in the module-level
    ``detection_model_helmet``; later calls return that cached instance without
    reloading it.

    Returns:
        The detection model object stored in ``detection_model_helmet``.

    Raises:
        FileNotFoundError: If no file exists at ``MODEL_WEIGHTS_PATH_HELMET``.
        RuntimeError: If the model loader raises for any reason.
    """
    global detection_model_helmet
    if detection_model_helmet is not None:
        return detection_model_helmet

    if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
        error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)

    logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
    try:
        loaded_model = AutoDetectionModel.from_pretrained(
            model_type="yolov8",
            model_path=MODEL_WEIGHTS_PATH_HELMET,
            confidence_threshold=0.83,
            device=DEVICE,
        )
    except Exception as e:
        logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
        raise RuntimeError(f"Could not load hardhat state model: {e}") from e

    # Optional sanity check. It is kept outside the loading ``try`` so that a
    # problem while inspecting the class names can never turn a successful load
    # into a failure.
    try:
        names = loaded_model.model.names
        # The comparison is order-sensitive on purpose: predictions are mapped to
        # names by indexing CLASS_NAMES_PPE_STATE_MODEL with the class id, so the
        # same names in a different order would silently mislabel detections.
        if isinstance(names, dict):
            loaded_class_names = [names[class_id] for class_id in sorted(names)]
        else:
            loaded_class_names = list(names)
        if loaded_class_names != CLASS_NAMES_PPE_STATE_MODEL:
            logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
                           f"but model has {loaded_class_names}. Please check consistency.")
    except Exception as e:
        logger.warning(f"Could not verify model class names: {e}")

    # Publish the model only after loading has fully succeeded.
    detection_model_helmet = loaded_model
    logger.info("hardhat state detection model loaded successfully.")
    return detection_model_helmet
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
    """Run SAHI sliced inference and keep only predictions with a known class id.

    Args:
        image_cv: Image array handed to SAHI as-is.
        model: Detection model object passed to ``get_sliced_prediction``.
        sahi_params: Mapping providing ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        model_class_names: Class name list; only its length is used, to decide
            whether a predicted class id is in range.

    Returns:
        The SAHI object predictions whose class id indexes into
        ``model_class_names``. An empty list is returned when inference raises
        (the error is logged) or when SAHI yields a falsy result.
    """
    try:
        # Built inside the ``try`` so a missing key is handled like any other
        # inference failure.
        slicing_options = {
            key: sahi_params[key]
            for key in ("slice_height", "slice_width", "overlap_height_ratio", "overlap_width_ratio")
        }
        prediction_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.5,
            **slicing_options,
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    if not prediction_result:
        return []

    kept = []
    for candidate in prediction_result.object_prediction_list:
        category_index = int(candidate.category.id)
        if category_index < 0 or category_index >= len(model_class_names):
            logger.warning(f"SAHI Raw: Detected class ID {category_index} is out of range. Ignoring prediction.")
            continue
        kept.append(candidate)
    return kept
def map_model_output_to_backend_violations(class_name: str) -> list[str]:
    """Translate a model state class name into the backend's violation codes.

    Args:
        class_name: State class name produced by the model.

    Returns:
        A new list of backend violation codes: the helmet code for
        ``"nohelmet"``, the vest code for ``"novest"``, both (helmet first) for
        ``"noequipment"``, and an empty list for any other name.
    """
    codes_by_state = (
        ("nohelmet", (BACKEND_VIOLATION_CODE_NO_HELMET,)),
        ("novest", (BACKEND_VIOLATION_CODE_NO_VEST,)),
        ("noequipment", (BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST)),
    )
    for state, codes in codes_by_state:
        if class_name == state:
            # Hand back a fresh list so callers may mutate the result freely.
            return list(codes)
    return []
# Post-processing thresholds used to reject boxes that are implausible for a person.
_MIN_PERSON_HEIGHT_PX = 40
_MAX_PERSON_WIDTH_RATIO = 0.21
_MAX_PERSON_HEIGHT_RATIO = 0.33
_MIN_ASPECT_RATIO = 1.2
_MAX_ASPECT_RATIO = 5.0
_FEET_POSITION_THRESHOLD_RATIO = 0.05
_LARGE_BOX_AREA_RATIO = 0.1
_LARGE_BOX_MIN_SCORE = 0.5


def _is_plausible_person_box(minx, miny, maxx, maxy, score, img_width, img_height) -> bool:
    """Return True when a box passes every size, shape and position filter."""
    bbox_width = maxx - minx
    bbox_height = maxy - miny
    if bbox_height < _MIN_PERSON_HEIGHT_PX:
        return False
    if (bbox_width / img_width > _MAX_PERSON_WIDTH_RATIO
            or bbox_height / img_height > _MAX_PERSON_HEIGHT_RATIO):
        return False
    if bbox_width == 0:
        return False
    # Height divided by width: the box must be taller than it is wide, within limits.
    if not (_MIN_ASPECT_RATIO <= bbox_height / bbox_width <= _MAX_ASPECT_RATIO):
        return False
    # Reject boxes whose bottom edge lies within the top strip of the image.
    if maxy < img_height * _FEET_POSITION_THRESHOLD_RATIO:
        return False
    # NOTE(review): with the current ratio limits (0.21 * 0.33 < 0.1) a box that
    # reaches this point cannot exceed the area limit, so this rule is presently
    # inert; it is kept in case the limits above are retuned.
    if (bbox_width * bbox_height > img_width * img_height * _LARGE_BOX_AREA_RATIO
            and score < _LARGE_BOX_MIN_SCORE):
        return False
    return True


def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
    """Detect PPE violations in an image and report only the requested ones.

    The model is loaded on first use if it has not been initialized yet. Raw
    SAHI predictions are filtered by box geometry, translated into backend
    violation codes, and finally restricted to ``requested_violations``.

    Args:
        image_cv: Image array; its first two dimensions are read as height and
            width. (presumably an OpenCV image - verify against the caller)
        requested_violations: Violation codes the caller wants reported.
        extract: Accepted but not used anywhere in this function.

    Returns:
        A list of dicts with keys ``"type"`` (violation code), ``"size"``
        (``[width, height]``), ``"leftTopPoint"`` (``[x, y]``) and ``"score"``
        (rounded to 4 decimals). A prediction that maps to several requested
        codes produces one dict per code. Empty when nothing is detected.

    Raises:
        RuntimeError: If the model has to be initialized on demand and that
            initialization fails.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("hardhat state model was not loaded. Initializing on first call.")
        try:
            initialize_helmet_model()
        except Exception as e:
            logger.error(f"Failed to initialize model on demand: {e}")
            raise RuntimeError(f"hardhat model could not be initialized: {e}") from e

    logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")
    # Read only the first two dimensions so 2-D (single-channel) arrays work as
    # well as arrays with a channel axis.
    img_height, img_width = image_cv.shape[:2]

    raw_sahi_predictions = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
    )
    if not raw_sahi_predictions:
        logger.info("No raw objects detected by SAHI.")
        return []
    logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")

    # --- Post-processing filters ---
    # Keep each surviving prediction together with its integer box so the box is
    # extracted only once.
    filtered_predictions = []
    for pred in raw_sahi_predictions:
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
        if _is_plausible_person_box(minx, miny, maxx, maxy, pred.score.value, img_width, img_height):
            filtered_predictions.append((pred, (minx, miny, maxx, maxy)))
    logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")

    # --- Mapping to violation codes and final filtering by request ---
    targets_output_list = []
    for pred, (minx, miny, maxx, maxy) in filtered_predictions:
        class_name = CLASS_NAMES_PPE_STATE_MODEL[int(pred.category.id)]
        score = round(float(pred.score.value), 4)
        logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")
        final_violations_to_report = [
            code
            for code in map_model_output_to_backend_violations(class_name)
            if code in requested_violations
        ]
        if not final_violations_to_report:
            continue
        bbox_width = maxx - minx
        bbox_height = maxy - miny
        for code in final_violations_to_report:
            logger.info(f" └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
            targets_output_list.append({
                "type": code,
                "size": [int(bbox_width), int(bbox_height)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": score
            })
    logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
    return targets_output_list