现有服务备份

This commit is contained in:
2025-07-24 12:45:27 +08:00
commit 7ae047c7c2
23 changed files with 161390 additions and 0 deletions

199
utils/hardhat_detector.py Normal file
View File

@ -0,0 +1,199 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List # 导入Set和List以支持类型提示
# 使用在main.py中配置好的根logger
logger = logging.getLogger(__name__)
# --- 1. 配置模型和类别信息 (这是需要您根据实际模型调整的部分) ---
# 新的“状态检测”模型的类别名称列表。顺序必须与模型训练时一致。
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]
# 后端期望接收的违规类型字符串
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"
# 新模型的权重文件路径
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"
# --- 2. 保留并适配现有的调用接口 ---
# SAHI切片参数
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}
# 全局模型实例
detection_model_helmet = None # 变量名保持'detection_model_helmet'与旧版一致
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def initialize_helmet_model():
"""
符合 main.py 调用约定的初始化函数。
它加载新的 PPE 状态模型。
"""
global detection_model_helmet
if detection_model_helmet is None:
if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
logger.error(error_msg)
raise FileNotFoundError(error_msg)
logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
try:
detection_model_helmet = AutoDetectionModel.from_pretrained(
model_type="yolov8",
model_path=MODEL_WEIGHTS_PATH_HELMET,
confidence_threshold=0.83,
device=DEVICE,
)
# 可选的验证步骤
loaded_class_names = list(detection_model_helmet.model.names.values())
if sorted(loaded_class_names) != sorted(CLASS_NAMES_PPE_STATE_MODEL):
logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
f"but model has {loaded_class_names}. Please check consistency.")
logger.info("hardhat state detection model loaded successfully.")
except Exception as e:
logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
raise RuntimeError(f"Could not load hardhat state model: {e}")
return detection_model_helmet
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
"""
辅助函数执行SAHI推理并返回所有有效预测的列表。
"""
try:
sahi_result = get_sliced_prediction(
image=image_cv,
detection_model=model,
slice_height=sahi_params["slice_height"],
slice_width=sahi_params["slice_width"],
overlap_height_ratio=sahi_params["overlap_height_ratio"],
overlap_width_ratio=sahi_params["overlap_width_ratio"],
verbose=0,
postprocess_type="GREEDYNMM",
postprocess_match_metric="IOS",
postprocess_match_threshold=0.5
)
except Exception as e:
logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
return []
object_predictions = sahi_result.object_prediction_list if sahi_result else []
valid_predictions = []
for pred in object_predictions:
cat_id = int(pred.category.id)
if 0 <= cat_id < len(model_class_names):
valid_predictions.append(pred)
else:
logger.warning(f"SAHI Raw: Detected class ID {cat_id} is out of range. Ignoring prediction.")
return valid_predictions
def map_model_output_to_backend_violations(class_name: str) -> list[str]:
"""
将模型输出的状态类别名“翻译”成后端需要的违规代码列表。
"""
if class_name == "nohelmet":
return [BACKEND_VIOLATION_CODE_NO_HELMET]
elif class_name == "novest":
return [BACKEND_VIOLATION_CODE_NO_VEST]
elif class_name == "noequipment":
return [BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST]
return []
def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
"""
主检测函数现在接收一个`requested_violations`集合参数,用于过滤最终的输出。
"""
global detection_model_helmet
if detection_model_helmet is None:
logger.warning("hardhat state model was not loaded. Initializing on first call.")
try:
initialize_helmet_model()
except Exception as e:
logger.error(f"Failed to initialize model on demand: {e}")
raise RuntimeError(f"hardhat model could not be initialized: {e}")
logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")
img_height, img_width, _ = image_cv.shape
raw_sahi_predictions = _run_sahi_and_process_results(
image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
)
if not raw_sahi_predictions:
logger.info("No raw objects detected by SAHI.")
return []
logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")
# --- 后处理过滤逻辑 ---
filtered_predictions = []
MIN_PERSON_HEIGHT_PX = 40
MAX_PERSON_WIDTH_RATIO = 0.21
MAX_PERSON_HEIGHT_RATIO = 0.33
MIN_ASPECT_RATIO = 1.2
MAX_ASPECT_RATIO = 5.0
FEET_POSITION_THRESHOLD_RATIO = 0.05
for pred in raw_sahi_predictions:
minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
score = pred.score.value
bbox_width = maxx - minx
bbox_height = maxy - miny
if bbox_height < MIN_PERSON_HEIGHT_PX: continue
relative_width = bbox_width / img_width
relative_height = bbox_height / img_height
if relative_width > MAX_PERSON_WIDTH_RATIO or relative_height > MAX_PERSON_HEIGHT_RATIO: continue
if bbox_width == 0: continue
aspect_ratio = bbox_height / bbox_width
if not (MIN_ASPECT_RATIO <= aspect_ratio <= MAX_ASPECT_RATIO): continue
if maxy < (img_height * FEET_POSITION_THRESHOLD_RATIO): continue
bbox_area = bbox_width * bbox_height
image_area = img_width * img_height
if bbox_area > (image_area * 0.1) and score < 0.5: continue
filtered_predictions.append(pred)
logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")
# --- 映射和最终过滤逻辑 ---
targets_output_list = []
for pred in filtered_predictions:
class_id = int(pred.category.id)
class_name = CLASS_NAMES_PPE_STATE_MODEL[class_id]
score = round(float(pred.score.value), 4)
minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")
all_possible_violations = map_model_output_to_backend_violations(class_name)
final_violations_to_report = [
v for v in all_possible_violations if v in requested_violations
]
if not final_violations_to_report:
continue
bbox_width = maxx - minx
bbox_height = maxy - miny
for code in final_violations_to_report:
logger.info(f" └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
targets_output_list.append({
"type": code,
"size": [int(bbox_width), int(bbox_height)],
"leftTopPoint": [int(minx), int(miny)],
"score": score
})
logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
return targets_output_list