import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List  # Set and List are imported for type hints

# Use the root logger configured in main.py
logger = logging.getLogger(__name__)

# --- 1. Model and class configuration (adjust this section to match your actual model) ---

# Class names of the new "state detection" model. The order must match the order used during training.
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]

# Violation type strings expected by the backend
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"

# Path to the new model's weight file
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"


# --- 2. Keep and adapt the existing call interface ---

# SAHI slicing parameters
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}

# Global model instance
detection_model_helmet = None  # The name 'detection_model_helmet' is kept for compatibility with the old version
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def initialize_helmet_model():
    """
    Initialization function that follows the calling convention expected by main.py.
    It loads the new PPE state model.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
            error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
        try:
            detection_model_helmet = AutoDetectionModel.from_pretrained(
                model_type="yolov8",
                model_path=MODEL_WEIGHTS_PATH_HELMET,
                confidence_threshold=0.83,
                device=DEVICE,
            )
            # Optional sanity check on the loaded class names
            loaded_class_names = list(detection_model_helmet.model.names.values())
            if sorted(loaded_class_names) != sorted(CLASS_NAMES_PPE_STATE_MODEL):
                logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
                               f"but model has {loaded_class_names}. Please check consistency.")

            logger.info("Hardhat state detection model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
            raise RuntimeError(f"Could not load hardhat state model: {e}") from e
    return detection_model_helmet


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
    """
    Helper that runs SAHI sliced inference and returns a list of all valid predictions.
    """
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.5
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    object_predictions = sahi_result.object_prediction_list if sahi_result else []

    valid_predictions = []
    for pred in object_predictions:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < len(model_class_names):
            valid_predictions.append(pred)
        else:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id} is out of range. Ignoring prediction.")
    return valid_predictions


def map_model_output_to_backend_violations(class_name: str) -> List[str]:
    """
    "Translate" the state class name emitted by the model into the list of violation codes
    expected by the backend.
    """
    if class_name == "nohelmet":
        return [BACKEND_VIOLATION_CODE_NO_HELMET]
    elif class_name == "novest":
        return [BACKEND_VIOLATION_CODE_NO_VEST]
    elif class_name == "noequipment":
        return [BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST]
    return []

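
# Illustrative examples of the mapping above ("wearingall" and any unrecognized class name
# yield no violation codes):
#   map_model_output_to_backend_violations("nohelmet")    -> ["nohelmet"]
#   map_model_output_to_backend_violations("noequipment") -> ["nohelmet", "novest"]
#   map_model_output_to_backend_violations("wearingall")  -> []
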

def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False) -> List[dict]:
    """
    Main detection function. It now takes a `requested_violations` set that is used to filter
    the final output. The `extract` flag is accepted for compatibility with the existing call
    interface and is not used here.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("hardhat state model was not loaded. Initializing on first call.")
        try:
            initialize_helmet_model()
        except Exception as e:
            logger.error(f"Failed to initialize model on demand: {e}")
            raise RuntimeError(f"hardhat model could not be initialized: {e}") from e

    logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")

    img_height, img_width, _ = image_cv.shape

    raw_sahi_predictions = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
    )
    if not raw_sahi_predictions:
        logger.info("No raw objects detected by SAHI.")
        return []

    logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")

    # --- Post-processing filter logic ---
    filtered_predictions = []
    MIN_PERSON_HEIGHT_PX = 40
    MAX_PERSON_WIDTH_RATIO = 0.21
    MAX_PERSON_HEIGHT_RATIO = 0.33
    MIN_ASPECT_RATIO = 1.2
    MAX_ASPECT_RATIO = 5.0
    FEET_POSITION_THRESHOLD_RATIO = 0.05

    for pred in raw_sahi_predictions:
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
        score = pred.score.value
        bbox_width = maxx - minx
        bbox_height = maxy - miny
        if bbox_height < MIN_PERSON_HEIGHT_PX: continue
        relative_width = bbox_width / img_width
        relative_height = bbox_height / img_height
        if relative_width > MAX_PERSON_WIDTH_RATIO or relative_height > MAX_PERSON_HEIGHT_RATIO: continue
        if bbox_width == 0: continue
        aspect_ratio = bbox_height / bbox_width
        if not (MIN_ASPECT_RATIO <= aspect_ratio <= MAX_ASPECT_RATIO): continue
        if maxy < (img_height * FEET_POSITION_THRESHOLD_RATIO): continue
        bbox_area = bbox_width * bbox_height
        image_area = img_width * img_height
        if bbox_area > (image_area * 0.1) and score < 0.5: continue
        filtered_predictions.append(pred)

    logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")

    # --- Mapping and final filtering logic ---
    targets_output_list = []
    for pred in filtered_predictions:
        class_id = int(pred.category.id)
        class_name = CLASS_NAMES_PPE_STATE_MODEL[class_id]
        score = round(float(pred.score.value), 4)
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())

        logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")

        all_possible_violations = map_model_output_to_backend_violations(class_name)

        final_violations_to_report = [
            v for v in all_possible_violations if v in requested_violations
        ]

        if not final_violations_to_report:
            continue

        bbox_width = maxx - minx
        bbox_height = maxy - miny

        for code in final_violations_to_report:
            logger.info(f"  └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
            targets_output_list.append({
                "type": code,
                "size": [int(bbox_width), int(bbox_height)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": score
            })

    logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
    return targets_output_list
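

# Minimal usage sketch (illustrative only; in production this module is driven by main.py).
# The image path and the requested violation set below are assumptions for demonstration.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    initialize_helmet_model()
    sample = cv2.imread("sample_site_photo.jpg")  # hypothetical test image
    if sample is None:
        raise SystemExit("Test image not found; supply any site photo to try the pipeline.")
    targets = detect_hardhat_with_sahi(sample, requested_violations={"nohelmet", "novest"})
    for t in targets:
        print(t["type"], t["score"], t["leftTopPoint"], t["size"])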