import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List  # Set and List for type hints

# Use the root logger configured in main.py
logger = logging.getLogger(__name__)

# --- 1. Model and class configuration (adjust this section to match your actual model) ---
# Class names of the new "state detection" model. The order must match the order used during training.
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]
# Violation type strings expected by the backend
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"
# Path to the new model's weights file
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"

# --- 2. Preserve and adapt the existing call interface ---
# SAHI slicing parameters
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}

# Global model instance
detection_model_helmet = None  # Name kept as 'detection_model_helmet' for compatibility with the old version
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def initialize_helmet_model():
    """
    Initialization function that follows the calling convention expected by main.py.
    It loads the new PPE state detection model.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
            error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
        try:
            detection_model_helmet = AutoDetectionModel.from_pretrained(
                model_type="yolov8",
                model_path=MODEL_WEIGHTS_PATH_HELMET,
                confidence_threshold=0.83,
                device=DEVICE,
            )
            # Optional sanity check: compare the loaded class names with the expected list
            loaded_class_names = list(detection_model_helmet.model.names.values())
            if sorted(loaded_class_names) != sorted(CLASS_NAMES_PPE_STATE_MODEL):
                logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
                               f"but model has {loaded_class_names}. Please check consistency.")
            logger.info("Hardhat state detection model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
            raise RuntimeError(f"Could not load hardhat state model: {e}")
    return detection_model_helmet


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
    """
    Helper that runs SAHI sliced inference and returns a list of all valid predictions.
    """
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.5,
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    object_predictions = sahi_result.object_prediction_list if sahi_result else []
    valid_predictions = []
    for pred in object_predictions:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < len(model_class_names):
            valid_predictions.append(pred)
        else:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id} is out of range. Ignoring prediction.")
    return valid_predictions


def map_model_output_to_backend_violations(class_name: str) -> List[str]:
    """
    Translate a state class name emitted by the model into the list of violation codes
    expected by the backend.
    """
    if class_name == "nohelmet":
        return [BACKEND_VIOLATION_CODE_NO_HELMET]
    elif class_name == "novest":
        return [BACKEND_VIOLATION_CODE_NO_VEST]
    elif class_name == "noequipment":
        return [BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST]
    return []


def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
    """
    Main detection function. It now takes a `requested_violations` set that is used to
    filter the final output. The `extract` flag is accepted for interface compatibility
    and is currently unused.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("Hardhat state model was not loaded. Initializing on first call.")
        try:
            initialize_helmet_model()
        except Exception as e:
            logger.error(f"Failed to initialize model on demand: {e}")
            raise RuntimeError(f"Hardhat model could not be initialized: {e}")

    logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")
    img_height, img_width, _ = image_cv.shape

    raw_sahi_predictions = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
    )

    if not raw_sahi_predictions:
        logger.info("No raw objects detected by SAHI.")
        return []

    logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")

    # --- Post-processing filter logic ---
    filtered_predictions = []
    MIN_PERSON_HEIGHT_PX = 40
    MAX_PERSON_WIDTH_RATIO = 0.21
    MAX_PERSON_HEIGHT_RATIO = 0.33
    MIN_ASPECT_RATIO = 1.2
    MAX_ASPECT_RATIO = 5.0
    FEET_POSITION_THRESHOLD_RATIO = 0.05

    for pred in raw_sahi_predictions:
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
        score = pred.score.value
        bbox_width = maxx - minx
        bbox_height = maxy - miny

        # Reject boxes that are too small to be a person
        if bbox_height < MIN_PERSON_HEIGHT_PX:
            continue

        # Reject boxes that are implausibly large relative to the frame
        relative_width = bbox_width / img_width
        relative_height = bbox_height / img_height
        if relative_width > MAX_PERSON_WIDTH_RATIO or relative_height > MAX_PERSON_HEIGHT_RATIO:
            continue

        # Reject boxes whose aspect ratio does not look like a standing person
        if bbox_width == 0:
            continue
        aspect_ratio = bbox_height / bbox_width
        if not (MIN_ASPECT_RATIO <= aspect_ratio <= MAX_ASPECT_RATIO):
            continue

        # Reject boxes whose bottom edge sits in the very top strip of the frame
        if maxy < (img_height * FEET_POSITION_THRESHOLD_RATIO):
            continue

        # Reject large boxes with low confidence
        bbox_area = bbox_width * bbox_height
        image_area = img_width * img_height
        if bbox_area > (image_area * 0.1) and score < 0.5:
            continue

        filtered_predictions.append(pred)

    logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")

    # --- Mapping and final filtering logic ---
    targets_output_list = []
    for pred in filtered_predictions:
        class_id = int(pred.category.id)
        class_name = CLASS_NAMES_PPE_STATE_MODEL[class_id]
        score = round(float(pred.score.value), 4)
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())

        logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, "
                    f"BBox: [{minx},{miny},{maxx},{maxy}]")

        all_possible_violations = map_model_output_to_backend_violations(class_name)
        final_violations_to_report = [
            v for v in all_possible_violations if v in requested_violations
        ]

        if not final_violations_to_report:
            continue

        bbox_width = maxx - minx
        bbox_height = maxy - miny
        for code in final_violations_to_report:
            logger.info(f"  └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
            targets_output_list.append({
                "type": code,
                "size": [int(bbox_width), int(bbox_height)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": score,
            })

    logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
    return targets_output_list