Backup of existing services
199 utils/hardhat_detector.py Normal file
@@ -0,0 +1,199 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List  # Import Set and List for type hints

# Use the root logger configured in main.py
logger = logging.getLogger(__name__)

# --- 1. Model and class configuration (adjust this section to match your actual model) ---

# Class names of the new "state detection" model. The order must match the one used during training.
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]

# Violation type strings expected by the backend
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"

# Path to the weights file of the new model
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"


# --- 2. Keep and adapt the existing calling interface ---

# SAHI slicing parameters
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}
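# Note: SAHI slices the input image into 1440x1440 tiles with 30% overlap between
# neighbouring tiles, runs the detector on each tile, and merges the per-tile
# detections back into full-image coordinates. These values appear tuned for
# spotting small, distant workers in high-resolution frames rather than being
# SAHI defaults.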

# Global model instance
detection_model_helmet = None  # Keep the variable name 'detection_model_helmet' consistent with the old version
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def initialize_helmet_model():
    """
    Initialization function that follows the calling convention expected by main.py.
    It loads the new PPE state model.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
            error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
        try:
            detection_model_helmet = AutoDetectionModel.from_pretrained(
                model_type="yolov8",
                model_path=MODEL_WEIGHTS_PATH_HELMET,
                confidence_threshold=0.83,
                device=DEVICE,
            )
            # Optional validation step: warn if the loaded class names differ from what the code expects
            loaded_class_names = list(detection_model_helmet.model.names.values())
            if sorted(loaded_class_names) != sorted(CLASS_NAMES_PPE_STATE_MODEL):
                logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
                               f"but model has {loaded_class_names}. Please check consistency.")

            logger.info("hardhat state detection model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
            raise RuntimeError(f"Could not load hardhat state model: {e}") from e
    return detection_model_helmet
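
# Assumption: main.py calls initialize_helmet_model() once at service startup so the
# first detection request does not pay the model-loading cost; detect_hardhat_with_sahi()
# below also initializes the model lazily as a fallback.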

def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
    """
    Helper that runs SAHI sliced inference and returns the list of all valid predictions.
    """
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.5
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    object_predictions = sahi_result.object_prediction_list if sahi_result else []

    valid_predictions = []
    for pred in object_predictions:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < len(model_class_names):
            valid_predictions.append(pred)
        else:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id} is out of range. Ignoring prediction.")
    return valid_predictions

def map_model_output_to_backend_violations(class_name: str) -> List[str]:
    """
    Translate a state class name output by the model into the list of violation codes expected by the backend.
    """
    if class_name == "nohelmet":
        return [BACKEND_VIOLATION_CODE_NO_HELMET]
    elif class_name == "novest":
        return [BACKEND_VIOLATION_CODE_NO_VEST]
    elif class_name == "noequipment":
        return [BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST]
    return []
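
# Examples: "noequipment" maps to ["nohelmet", "novest"], while "wearingall" (fully
# equipped) and any unrecognized class name map to [] and therefore report nothing.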

def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
    """
    Main detection function. It now takes a `requested_violations` set that is used to
    filter the final output. The `extract` flag is accepted for backward compatibility
    with the old interface and is not used here.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("hardhat state model was not loaded. Initializing on first call.")
        try:
            initialize_helmet_model()
        except Exception as e:
            logger.error(f"Failed to initialize model on demand: {e}")
            raise RuntimeError(f"hardhat model could not be initialized: {e}") from e

    logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")

    img_height, img_width, _ = image_cv.shape

    raw_sahi_predictions = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
    )
    if not raw_sahi_predictions:
        logger.info("No raw objects detected by SAHI.")
        return []

    logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")

    # --- Post-processing filter logic ---
    filtered_predictions = []
    MIN_PERSON_HEIGHT_PX = 40
    MAX_PERSON_WIDTH_RATIO = 0.21
    MAX_PERSON_HEIGHT_RATIO = 0.33
    MIN_ASPECT_RATIO = 1.2
    MAX_ASPECT_RATIO = 5.0
    FEET_POSITION_THRESHOLD_RATIO = 0.05

    for pred in raw_sahi_predictions:
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
        score = pred.score.value
        bbox_width = maxx - minx
        bbox_height = maxy - miny
        # Discard boxes that are too small, too large relative to the frame, have a
        # non-person-like aspect ratio, sit entirely in the top 5% of the frame,
        # or cover a large area with low confidence.
        if bbox_height < MIN_PERSON_HEIGHT_PX:
            continue
        relative_width = bbox_width / img_width
        relative_height = bbox_height / img_height
        if relative_width > MAX_PERSON_WIDTH_RATIO or relative_height > MAX_PERSON_HEIGHT_RATIO:
            continue
        if bbox_width == 0:
            continue
        aspect_ratio = bbox_height / bbox_width
        if not (MIN_ASPECT_RATIO <= aspect_ratio <= MAX_ASPECT_RATIO):
            continue
        if maxy < (img_height * FEET_POSITION_THRESHOLD_RATIO):
            continue
        bbox_area = bbox_width * bbox_height
        image_area = img_width * img_height
        if bbox_area > (image_area * 0.1) and score < 0.5:
            continue
        filtered_predictions.append(pred)

    logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")

    # --- Mapping and final filtering logic ---
    targets_output_list = []
    for pred in filtered_predictions:
        class_id = int(pred.category.id)
        class_name = CLASS_NAMES_PPE_STATE_MODEL[class_id]
        score = round(float(pred.score.value), 4)
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())

        logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")

        all_possible_violations = map_model_output_to_backend_violations(class_name)

        final_violations_to_report = [
            v for v in all_possible_violations if v in requested_violations
        ]

        if not final_violations_to_report:
            continue

        bbox_width = maxx - minx
        bbox_height = maxy - miny

        for code in final_violations_to_report:
            logger.info(f"  └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
            targets_output_list.append({
                "type": code,
                "size": [int(bbox_width), int(bbox_height)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": score
            })

    logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
    return targets_output_list
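
For reference, a minimal usage sketch of this module (not part of the committed file); the image path, the requested violation set, and running from the repository root are assumptions:

import cv2
from utils.hardhat_detector import initialize_helmet_model, detect_hardhat_with_sahi

initialize_helmet_model()                # load the weights once, e.g. at service startup
frame = cv2.imread("site_camera.jpg")    # hypothetical test image, read as a BGR np.ndarray
targets = detect_hardhat_with_sahi(frame, requested_violations={"nohelmet", "novest"})
for t in targets:
    print(t["type"], t["score"], t["leftTopPoint"], t["size"])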