"""Two-stage smoking-behaviour detection (face -> cigarette).

Stage 1 runs SAHI sliced inference over the whole image to find faces and
filters them by score, size and aspect ratio. Stage 2 runs the same YOLO
weights directly on a padded crop around each qualified face; cigarettes found
in that crop are reported only when the crop also re-detects a face and the
cigarette box passes absolute-size and aspect-ratio sanity checks.
"""
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import torch  # used for device detection and YOLO model loading
from ultralytics import YOLO  # used for stage-2 ROI detection

logger = logging.getLogger(__name__)
# Makes this module's logger output visible.
# NOTE(review): configuring the root logger at import time affects the whole
# application; consider moving this to the program entry point.
logging.basicConfig(level=logging.INFO)

# --- Configuration ---
# Class IDs are indices into this list and must match the model's own ID order.
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID = 0
FACE_CLASS_ID = 1
MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"

# SAHI slicing parameters (stage 1). Large slices are tried first; small slices
# are only used when the large-slice pass finds no face at all.
SAHI_PARAMS_FACE_LARGE = {"slice_height": 1280, "slice_width": 1280, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_FACE_SMALL = {"slice_height": 640, "slice_width": 640, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# Confidence threshold given to the SAHI-wrapped model (stage 1).
SAHI_CONFIDENCE_THRESHOLD = 0.7

# YOLO confidence threshold for detections inside a face ROI (stage 2).
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
# Pixels added on each side of a face box before cropping the ROI.
FACE_ROI_PADDING = 30

# Face qualification (stage 1 post-filter); tune as needed.
FACE_MIN_SCORE = 0.65
FACE_MIN_SIZE_PX = 32
FACE_ASPECT_RATIO_RANGE = (0.6, 1.6)  # exclusive bounds on width / height

# Cigarette qualification (stage 2 post-filter), in pixels; tune as needed.
CIGARETTE_MAX_LONG_SIDE_PX = 100
CIGARETTE_MAX_SHORT_SIDE_PX = 40
# A real cigarette is much longer than it is wide; this excludes oddly shaped
# objects (e.g. utility poles). Exclusive bounds on long side / short side.
CIGARETTE_ASPECT_RATIO_RANGE = (2.5, 25.0)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global model instances (lazily loaded by initialize_smoking_model).
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None


def _ordered_class_names(names) -> list:
    """Return model class names as a list indexed by class ID.

    ``names`` is an Ultralytics-style ``{class_id: name}`` dict; a plain
    sequence is returned as a list unchanged. Gaps in the ID range are filled
    with ``""``.
    """
    if not isinstance(names, dict):
        return list(names)
    if not names:
        return []
    ordered = [""] * (max(names) + 1)
    for class_id, name in names.items():
        ordered[class_id] = name
    return ordered


def _verify_class_names(names) -> None:
    """Log whether the model's ID -> name mapping matches CLASS_NAMES_BEHAVIOR.

    The comparison is order-sensitive on purpose: CIGARETTE_CLASS_ID and
    FACE_CLASS_ID are hard-coded indices, so the same names in a different
    order would silently swap the two classes.
    """
    if not names:
        logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")
        return
    model_names = _ordered_class_names(names)
    if model_names != CLASS_NAMES_BEHAVIOR:
        logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
                       f"Got from model: {model_names}. Please verify consistency.")
    else:
        logger.info(f"Model class names confirmed: {model_names}")


def initialize_smoking_model():
    """Load the behaviour weights as a SAHI model and as a plain YOLO model.

    Both module globals are (re)loaded when either one is still ``None``;
    otherwise nothing is loaded.

    Returns:
        tuple: ``(detection_model_sahi_behavior, yolo_model_direct_behavior)``.

    Raises:
        RuntimeError: if loading fails; both globals are reset to ``None`` and
            the original exception is chained as the cause.
    """
    global detection_model_sahi_behavior, yolo_model_direct_behavior
    if detection_model_sahi_behavior is not None and yolo_model_direct_behavior is not None:
        return detection_model_sahi_behavior, yolo_model_direct_behavior

    logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
    try:
        # 1. SAHI-wrapped model, used for sliced full-image detection (stage 1).
        detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
            confidence_threshold=SAHI_CONFIDENCE_THRESHOLD,
            device=DEVICE,
        )
        # 2. Model loaded directly with Ultralytics, used for fast ROI detection (stage 2).
        yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
        yolo_model_direct_behavior.to(DEVICE)
        _verify_class_names(getattr(yolo_model_direct_behavior, "names", None))
        logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading behavior model: {e}", exc_info=True)
        detection_model_sahi_behavior = None
        yolo_model_direct_behavior = None
        raise RuntimeError(f"Could not load behavior model: {e}") from e
    return detection_model_sahi_behavior, yolo_model_direct_behavior


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int,
                                  target_class_name: str, model_actual_class_names: list):
    """Run SAHI sliced prediction and keep only predictions of ``target_class_id``.

    Args:
        image_cv: image passed straight to ``get_sliced_prediction``.
        model: SAHI detection model (as built by ``AutoDetectionModel.from_pretrained``).
        sahi_params: dict with ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        target_class_id: class ID to keep.
        target_class_name: class name, used only in log messages.
        model_actual_class_names: model class names indexed by ID; predictions
            whose ID is outside this list are logged and dropped.

    Returns:
        list: SAHI prediction objects of the target class, or ``[]`` if SAHI
        raised (the error is logged, not propagated).
    """
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            # Merge duplicate boxes from overlapping slices, per class.
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.7,
            postprocess_class_agnostic=False,
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {e}", exc_info=True)
        return []

    valid_predictions = []
    num_classes = len(model_actual_class_names)
    for pred in getattr(sahi_result, "object_prediction_list", []):
        cat_id_from_sahi = int(pred.category.id)
        if not 0 <= cat_id_from_sahi < num_classes:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id_from_sahi} for '{target_class_name}' "
                           f"is out of range for actual model class names ({num_classes} classes). Ignoring pred.")
            continue
        if cat_id_from_sahi == target_class_id:
            valid_predictions.append(pred)
    logger.info(f"SAHI found {len(valid_predictions)} instances of '{target_class_name}' with params {sahi_params}.")
    return valid_predictions


def _detect_faces_with_sahi(image_cv: np.ndarray, sahi_model, class_names_list: list) -> list:
    """Stage 1: find faces with large SAHI slices, falling back to small slices if none are found."""
    face_name = CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID]
    logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
    faces = _run_sahi_and_process_results(
        image_cv, sahi_model, SAHI_PARAMS_FACE_LARGE, FACE_CLASS_ID, face_name, class_names_list
    )
    if faces:
        return faces

    logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
    faces = _run_sahi_and_process_results(
        image_cv, sahi_model, SAHI_PARAMS_FACE_SMALL, FACE_CLASS_ID, face_name, class_names_list
    )
    if not faces:
        logger.info("No SAHI face detections found even with small slices.")
    return faces


def _qualify_faces(face_preds: list) -> list:
    """Keep only SAHI face predictions that pass the score, size and aspect-ratio checks."""
    min_ratio, max_ratio = FACE_ASPECT_RATIO_RANGE
    qualified_faces = []
    for face_pred in face_preds:
        # 1. Confidence check. Note: the SAHI model already drops predictions
        #    below SAHI_CONFIDENCE_THRESHOLD, so this only bites if that is lower.
        if face_pred.score.value < FACE_MIN_SCORE:
            continue
        # 2. Size check.
        x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
        width, height = x2 - x1, y2 - y1
        if width < FACE_MIN_SIZE_PX or height < FACE_MIN_SIZE_PX:
            continue
        # 3. Aspect-ratio check.
        aspect_ratio = width / height if height > 0 else 0
        if not min_ratio < aspect_ratio < max_ratio:
            continue
        qualified_faces.append(face_pred)
    logger.info(f"Face Qualification: Initial candidates: {len(face_preds)}, Qualified after filtering: {len(qualified_faces)}")
    return qualified_faces


def _is_plausible_cigarette(width: int, height: int) -> bool:
    """Return True if a cigarette box of ``width`` x ``height`` pixels passes the size and shape checks."""
    long_side = max(width, height)
    short_side = min(width, height)
    # 1. Absolute size check.
    if long_side > CIGARETTE_MAX_LONG_SIDE_PX or short_side > CIGARETTE_MAX_SHORT_SIDE_PX:
        logger.info(f" REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
        return False
    # Degenerate box: a line (one side zero) is logged; a point is dropped silently.
    if short_side <= 0:
        if long_side > 0:
            logger.info(" REJECTED Cigarette: BBox is a zero-width/height line.")
        return False
    # 2. Aspect-ratio check (guards against false positives such as utility poles).
    aspect_ratio = long_side / short_side
    min_ratio, max_ratio = CIGARETTE_ASPECT_RATIO_RANGE
    if not min_ratio < aspect_ratio < max_ratio:
        logger.info(f" REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
        return False
    return True


def _detect_cigarettes_in_face_roi(image_cv: np.ndarray, yolo_model, face_pred) -> list:
    """Stage 2: run YOLO on the padded ROI around one face and cross-validate.

    Returns the qualified cigarette records (in full-image coordinates) only
    when the ROI also contains a face detection; otherwise returns ``[]``.
    """
    x1_face, y1_face, x2_face, y2_face = map(int, face_pred.bbox.to_xyxy())
    img_h, img_w = image_cv.shape[:2]

    # Padded ROI, clipped to the image bounds.
    roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
    roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
    roi_x2 = min(img_w, x2_face + FACE_ROI_PADDING)
    roi_y2 = min(img_h, y2_face + FACE_ROI_PADDING)
    face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]
    if face_roi_crop.size == 0:
        logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
        return []

    roi_results = yolo_model.predict(source=face_roi_crop, verbose=False, device=DEVICE,
                                     conf=CIGARETTE_IN_ROI_CONF_THRESHOLD)

    is_face_confirmed_in_roi = False
    cigarette_details = []
    if roi_results:
        for box in roi_results[0].boxes:
            class_id = int(box.cls)
            if class_id == FACE_CLASS_ID:
                is_face_confirmed_in_roi = True
                continue
            if class_id != CIGARETTE_CLASS_ID:
                continue
            xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())
            cig_width = xc2_roi - xc1_roi
            cig_height = yc2_roi - yc1_roi
            if not _is_plausible_cigarette(cig_width, cig_height):
                continue
            # Only cigarettes that pass every check are kept; shift ROI coords to image coords.
            cigarette_details.append({
                "type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
                "size": [int(cig_width), int(cig_height)],
                "leftTopPoint": [int(xc1_roi + roi_x1), int(yc1_roi + roi_y1)],
                "score": round(float(box.conf), 4),
            })

    # Final verdict: the ROI must contain both a face and a qualified cigarette.
    found_cigarette_in_roi = bool(cigarette_details)
    if is_face_confirmed_in_roi and found_cigarette_in_roi:
        logger.info(f" SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
        return cigarette_details
    logger.info(f" INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")
    return []


def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
    """Detect cigarettes held near qualified faces using the two-stage pipeline.

    Models are lazily loaded via ``initialize_smoking_model`` if needed.

    Args:
        image_cv: image array passed to SAHI and Ultralytics (presumably a cv2
            BGR image — TODO confirm with callers).
        extract: currently unused; kept for backward compatibility.

    Returns:
        list | dict: on success, a list of dicts with keys ``type``, ``size``
        (``[w, h]``), ``leftTopPoint`` (``[x, y]``) and ``score``, one per
        accepted cigarette (empty if nothing qualifies). If the model cannot be
        loaded or its class names are missing, a dict ``{"error": ...}``.
    """
    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
        try:
            initialize_smoking_model()
        except RuntimeError:
            logger.error("Behavior model could not be initialized for detection.")
            return {"error": "Behavior model is not available."}

    # Read the globals only after the (possible) lazy initialisation above.
    sahi_model = detection_model_sahi_behavior
    yolo_model = yolo_model_direct_behavior
    if sahi_model is None or yolo_model is None:
        logger.error("Behavior model components are still None after initialization attempt.")
        return {"error": "Behavior model failed to initialize."}

    model_names = getattr(yolo_model, "names", None)
    if not model_names:
        logger.error("Cannot determine actual class names from the loaded YOLO model.")
        return {"error": "Model class names missing."}
    model_actual_class_names_list = _ordered_class_names(model_names)

    logger.info("Executing two-stage smoking behavior detection...")

    # --- Stage 1: detect faces with SAHI, then apply the strict qualification filter ---
    detected_faces = _detect_faces_with_sahi(image_cv, sahi_model, model_actual_class_names_list)
    if detected_faces:
        detected_faces = _qualify_faces(detected_faces)

    targets_output_list = []
    if not detected_faces:
        logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
        return targets_output_list

    logger.info(f"Found {len(detected_faces)} qualified faces in Stage 1. Proceeding to Stage 2...")

    # --- Stage 2: cigarette detection and cross-validation inside each face ROI ---
    for face_pred in detected_faces:
        targets_output_list.extend(_detect_cigarettes_in_face_roi(image_cv, yolo_model, face_pred))

    logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
    return targets_output_list