Files
AI-TianDong/utils/smoking_detector.py
2025-07-24 12:45:27 +08:00

258 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import torch # used for device detection and YOLO model loading
from ultralytics import YOLO # used for the stage-two ROI detection
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) # configure logging at INFO level so this logger's output is shown
# --- Updated configuration ---
# Class names indexed by class id; the two *_CLASS_ID constants below index into this list.
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID = 0
FACE_CLASS_ID = 1
# Weights file loaded both by the SAHI wrapper and by the direct YOLO model.
MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"
# SAHI slicing parameters: large slices are tried first, small slices are the fallback.
SAHI_PARAMS_FACE_LARGE = {"slice_height": 1280, "slice_width": 1280, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_FACE_SMALL = {"slice_height": 640, "slice_width": 640, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# YOLO confidence threshold for the in-ROI prediction (it is passed as `conf`,
# so it applies to every class predicted inside the ROI, not only cigarettes).
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
# Pixels added on every side of a face box before cropping the ROI.
FACE_ROI_PADDING = 30
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Global model instances, populated by initialize_smoking_model().
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None
def initialize_smoking_model():
    """Load the behavior (face & cigarette) detection models if not loaded yet.

    Two wrappers around the same weights file are created and stored in the
    module-level globals:

    * ``detection_model_sahi_behavior`` - SAHI ``AutoDetectionModel`` used for
      sliced, full-image detection.
    * ``yolo_model_direct_behavior`` - plain Ultralytics ``YOLO`` model used
      for the per-ROI detection.

    After loading, the class names reported by the YOLO model are compared
    with ``CLASS_NAMES_BEHAVIOR``. Because this module addresses classes by
    the hard-coded ids ``CIGARETTE_CLASS_ID`` / ``FACE_CLASS_ID``, a model
    whose names match but whose id order differs is reported with a warning
    as well (it would otherwise silently swap the two classes).

    Returns:
        Tuple ``(detection_model_sahi_behavior, yolo_model_direct_behavior)``.

    Raises:
        RuntimeError: if either model fails to load; both globals are reset
            to ``None`` in that case and the original error is chained.
    """
    global detection_model_sahi_behavior, yolo_model_direct_behavior

    # Fast path: both wrappers are already available.
    if detection_model_sahi_behavior is not None and yolo_model_direct_behavior is not None:
        return detection_model_sahi_behavior, yolo_model_direct_behavior

    logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
    try:
        # 1. SAHI-wrapped model for sliced detection over the whole image.
        detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
            confidence_threshold=0.7,
            device=DEVICE,
        )
        # 2. Model loaded directly through Ultralytics for fast ROI detection.
        yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
        yolo_model_direct_behavior.to(DEVICE)

        names = getattr(yolo_model_direct_behavior, 'names', None)
        if names:
            # Order the names by class id so that list position == class id,
            # which is how CLASS_NAMES_BEHAVIOR is indexed in this module.
            model_names = [name for _, name in sorted(names.items())]
            if sorted(model_names) != sorted(CLASS_NAMES_BEHAVIOR):
                logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
                               f"Got from model: {model_names}. Please verify consistency.")
            elif model_names != CLASS_NAMES_BEHAVIOR:
                # Same names, different ids: the hard-coded class ids would
                # point at the wrong classes.
                logger.warning(f"Class names match but their id order differs! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
                               f"Got from model: {model_names}. Hard-coded class ids will be wrong.")
            else:
                logger.info(f"Model class names confirmed: {model_names}")
        else:
            logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")
        logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading behavior model: {e}", exc_info=True)
        # Never leave a half-initialized pair behind.
        detection_model_sahi_behavior = None
        yolo_model_direct_behavior = None
        raise RuntimeError(f"Could not load behavior model: {e}") from e
    return detection_model_sahi_behavior, yolo_model_direct_behavior
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int, target_class_name: str, model_actual_class_names: list):
    """Run SAHI sliced inference and keep the predictions of a single class.

    Args:
        image_cv: Image passed to ``get_sliced_prediction`` as ``image``.
        model: SAHI detection model passed as ``detection_model``.
        sahi_params: Mapping providing ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        target_class_id: Class id a prediction must have to be kept.
        target_class_name: Name of that class; used only in log messages.
        model_actual_class_names: Class names of the model; only its length
            is used, to discard predictions whose id is out of range.

    Returns:
        List of SAHI prediction objects whose category id equals
        ``target_class_id``. An empty list is returned when the SAHI call
        raises (the error is logged, not propagated).
    """
    try:
        # Looked up inside the ``try`` so that a missing key is handled the
        # same way as any other prediction failure.
        slicing_options = {
            option: sahi_params[option]
            for option in ("slice_height", "slice_width", "overlap_height_ratio", "overlap_width_ratio")
        }
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.7,
            postprocess_class_agnostic=False,
            **slicing_options,
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {e}", exc_info=True)
        return []

    valid_predictions = []
    # A result without ``object_prediction_list`` simply yields no candidates.
    for pred in getattr(sahi_result, 'object_prediction_list', ()):
        cat_id_from_sahi = int(pred.category.id)
        if cat_id_from_sahi < 0 or cat_id_from_sahi >= len(model_actual_class_names):
            logger.warning(f"SAHI Raw: Detected class ID {cat_id_from_sahi} for '{target_class_name}' "
                           f"is out of range for actual model class names ({len(model_actual_class_names)} classes). Ignoring pred.")
            continue
        if cat_id_from_sahi == target_class_id:
            valid_predictions.append(pred)

    logger.info(f"SAHI found {len(valid_predictions)} instances of '{target_class_name}' with params {sahi_params}.")
    return valid_predictions
def _is_qualified_face(face_pred, min_score=0.65, min_side=32, aspect_range=(0.6, 1.6)):
    """Return True when a SAHI face prediction passes the score, size and aspect-ratio checks."""
    # 1. Confidence check.
    if face_pred.score.value < min_score:
        return False
    # 2. Size check (pixels).
    x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
    width, height = x2 - x1, y2 - y1
    if width < min_side or height < min_side:
        return False
    # 3. Width/height ratio check.
    aspect_ratio = width / height if height > 0 else 0
    return aspect_range[0] < aspect_ratio < aspect_range[1]


def _is_plausible_cigarette(width, height, max_long_side=100, max_short_side=40, aspect_range=(2.5, 25.0)):
    """Return True when a cigarette box passes the absolute-size and elongation checks; log each rejection."""
    long_side = max(width, height)
    short_side = min(width, height)
    # 1. Absolute size check.
    if long_side > max_long_side or short_side > max_short_side:
        logger.info(f" REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
        return False
    # 2. Elongation check (meant to reject e.g. utility poles): the long
    #    side must be clearly, but not absurdly, larger than the short side.
    if short_side > 0:
        aspect_ratio = long_side / short_side
        if not (aspect_range[0] < aspect_ratio < aspect_range[1]):
            logger.info(f" REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
            return False
        return True
    if long_side > 0:
        # Degenerate box: one side is zero, the other is not.
        logger.info(" REJECTED Cigarette: BBox is a zero-width/height line.")
    # Both sides zero: rejected silently.
    return False


def _scan_face_roi(image_cv, face_box):
    """Run the direct YOLO model on the padded ROI around one face box.

    Returns ``None`` when the ROI crop is empty, otherwise a tuple
    ``(is_face_confirmed_in_roi, cigarette_details)`` where the details are
    the output dicts (in full-image coordinates) of the accepted cigarettes.
    """
    x1_face, y1_face, x2_face, y2_face = face_box
    # Pad the face box and clamp it to the image bounds.
    roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
    roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
    roi_x2 = min(image_cv.shape[1], x2_face + FACE_ROI_PADDING)
    roi_y2 = min(image_cv.shape[0], y2_face + FACE_ROI_PADDING)
    face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]
    if face_roi_crop.size == 0:
        logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
        return None

    roi_results = yolo_model_direct_behavior.predict(source=face_roi_crop, verbose=False, device=DEVICE, conf=CIGARETTE_IN_ROI_CONF_THRESHOLD)
    is_face_confirmed_in_roi = False
    cigarette_details = []
    if roi_results:
        for box in roi_results[0].boxes:
            class_id = int(box.cls)
            if class_id == FACE_CLASS_ID:
                is_face_confirmed_in_roi = True
            elif class_id == CIGARETTE_CLASS_ID:
                xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())
                # Translate ROI-local coordinates back to the full image.
                cig_minx_global = xc1_roi + roi_x1
                cig_miny_global = yc1_roi + roi_y1
                cig_width_global = (xc2_roi + roi_x1) - cig_minx_global
                cig_height_global = (yc2_roi + roi_y1) - cig_miny_global
                if not _is_plausible_cigarette(cig_width_global, cig_height_global):
                    continue
                # Only cigarettes that passed every check are reported.
                cigarette_details.append({
                    "type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
                    "size": [int(cig_width_global), int(cig_height_global)],
                    "leftTopPoint": [int(cig_minx_global), int(cig_miny_global)],
                    "score": round(float(box.conf), 4)
                })
    return is_face_confirmed_in_roi, cigarette_details


def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
    """Detect cigarettes near faces with a two-stage pipeline.

    Stage 1 finds faces on the whole image with SAHI sliced inference (large
    slices first, small slices as fallback) and filters them by score, size
    and aspect ratio. Stage 2 runs the direct YOLO model on a padded crop
    around every remaining face; cigarettes from a crop are reported only
    when the crop also contains a face detection and the cigarette box
    passes the size/elongation checks.

    The models are initialized lazily if they have not been loaded yet.

    Args:
        image_cv: Image array; indexed as ``[row, column]`` and passed to
            SAHI / YOLO unchanged.
        extract: Currently unused by this function.

    Returns:
        On success, a list (possibly empty) of dicts with the keys ``type``,
        ``size`` (``[width, height]``), ``leftTopPoint`` (``[x, y]``) and
        ``score``, in full-image pixel coordinates. When the models are
        unavailable or expose no class names, a dict ``{"error": <message>}``
        is returned instead.
    """
    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
        try:
            initialize_smoking_model()
        except RuntimeError:
            logger.error("Behavior model could not be initialized for detection.")
            return {"error": "Behavior model is not available."}
        if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
            logger.error("Behavior model components are still None after initialization attempt.")
            return {"error": "Behavior model failed to initialize."}

    model_actual_class_names_dict = yolo_model_direct_behavior.names
    if not model_actual_class_names_dict:
        logger.error("Cannot determine actual class names from the loaded YOLO model.")
        return {"error": "Model class names missing."}
    # Dense id -> name table; ids missing from the model map to "".
    model_actual_class_names_list = [""] * (max(model_actual_class_names_dict.keys()) + 1)
    for id_val, name_val in model_actual_class_names_dict.items():
        model_actual_class_names_list[id_val] = name_val

    logger.info("Executing two-stage smoking behavior detection...")

    # --- Stage 1: detect faces with SAHI ---
    face_class_name = CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID]
    logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
    detected_faces_sahi = _run_sahi_and_process_results(
        image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_LARGE, FACE_CLASS_ID, face_class_name, model_actual_class_names_list
    )
    if not detected_faces_sahi:
        logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
        detected_faces_sahi = _run_sahi_and_process_results(
            image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_SMALL, FACE_CLASS_ID, face_class_name, model_actual_class_names_list
        )
        if not detected_faces_sahi:
            logger.info("No SAHI face detections found even with small slices.")

    # Strict qualification of the detected faces.
    if detected_faces_sahi:
        original_face_count = len(detected_faces_sahi)
        qualified_faces = [face_pred for face_pred in detected_faces_sahi if _is_qualified_face(face_pred)]
        logger.info(f"Face Qualification: Initial candidates: {original_face_count}, Qualified after filtering: {len(qualified_faces)}")
        detected_faces_sahi = qualified_faces

    targets_output_list = []
    if not detected_faces_sahi:
        logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
        return targets_output_list
    logger.info(f"Found {len(detected_faces_sahi)} qualified faces in Stage 1. Proceeding to Stage 2...")

    # --- Stage 2: detect cigarettes inside every face ROI with direct YOLO ---
    for face_pred_sahi in detected_faces_sahi:
        x1_face, y1_face, x2_face, y2_face = map(int, face_pred_sahi.bbox.to_xyxy())
        roi_outcome = _scan_face_roi(image_cv, (x1_face, y1_face, x2_face, y2_face))
        if roi_outcome is None:
            continue
        is_face_confirmed_in_roi, cigarette_details_for_this_face = roi_outcome
        found_cigarette_in_roi = bool(cigarette_details_for_this_face)
        # Final verdict: the ROI must contain both a face and a qualified cigarette.
        if is_face_confirmed_in_roi and found_cigarette_in_roi:
            logger.info(f" SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
            targets_output_list.extend(cigarette_details_for_this_face)
        else:
            logger.info(f" INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")

    logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
    return targets_output_list