258 lines
13 KiB
Python
258 lines
13 KiB
Python
import cv2
|
||
import numpy as np
|
||
from sahi import AutoDetectionModel
|
||
from sahi.predict import get_sliced_prediction
|
||
import logging
|
||
import torch # 用于设备检测和YOLO模型加载
|
||
from ultralytics import YOLO # 用于阶段二的ROI检测
|
||
|
||
# Module logger. basicConfig() configures the root logger at import time so
# that this module's INFO-level messages are printed when no handler is set up.
# NOTE(review): configuring the root logger from a library module affects the
# whole process; consider moving this call to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Behaviour-model configuration ---

# Class names indexed by model class ID (ID 0 -> "smoke", ID 1 -> "face").
# The two ID constants must agree with the class order of the trained weights.
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID, FACE_CLASS_ID = 0, 1

# Weights file shared by the SAHI wrapper and the direct Ultralytics model.
MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"

# SAHI slicing parameters for face detection: a coarse pass with large tiles,
# and a finer pass with small tiles that is used as a fallback.
SAHI_PARAMS_FACE_LARGE = dict(
    slice_height=1280,
    slice_width=1280,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
SAHI_PARAMS_FACE_SMALL = dict(
    slice_height=640,
    slice_width=640,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

# Confidence threshold given to the direct YOLO model when it runs on a face ROI.
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
# Pixels added on every side of a face box before the ROI is cropped.
FACE_ROI_PADDING = 30

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Global model instances, populated lazily by initialize_smoking_model().
detection_model_sahi_behavior = yolo_model_direct_behavior = None
|
||
|
||
|
||
def initialize_smoking_model():
    """Lazily load the behaviour (smoke/face) detector in its two wrappers.

    The weights at MODEL_WEIGHTS_PATH_BEHAVIOR are loaded twice:

    1. as a SAHI ``AutoDetectionModel`` (confidence threshold 0.7), used for
       sliced whole-image face detection;
    2. as a plain Ultralytics ``YOLO`` model, used for fast inference on
       cropped face ROIs.

    Both are stored in module globals. If both are already loaded the call is
    a no-op; if either is missing, both are (re)loaded.

    After loading, the model's ``names`` mapping is compared *by class ID*
    against CLASS_NAMES_BEHAVIOR. This matters because CIGARETTE_CLASS_ID and
    FACE_CLASS_ID are hard-coded indices: a model with the same class names in
    a different order would silently swap cigarettes and faces.

    Returns:
        tuple: ``(detection_model_sahi_behavior, yolo_model_direct_behavior)``.

    Raises:
        RuntimeError: if loading fails. Both globals are reset to ``None``
            first, and the original exception is chained as ``__cause__``.
    """
    global detection_model_sahi_behavior, yolo_model_direct_behavior

    if detection_model_sahi_behavior is not None and yolo_model_direct_behavior is not None:
        return detection_model_sahi_behavior, yolo_model_direct_behavior

    logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
    try:
        # 1. SAHI-wrapped model for global sliced detection.
        detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
            confidence_threshold=0.7,
            device=DEVICE,
        )
        # 2. Directly loaded Ultralytics model for fast ROI detection.
        yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
        yolo_model_direct_behavior.to(DEVICE)

        # Verify the class-ID -> name mapping that the ID constants rely on.
        raw_names = getattr(yolo_model_direct_behavior, "names", None)
        if raw_names:
            actual = dict(raw_names) if isinstance(raw_names, dict) else dict(enumerate(raw_names))
            expected = dict(enumerate(CLASS_NAMES_BEHAVIOR))
            model_names = list(actual.values())
            if actual == expected:
                logger.info(f"Model class names confirmed: {model_names}")
            elif sorted(model_names) == sorted(CLASS_NAMES_BEHAVIOR):
                # Same classes, different IDs: the hard-coded ID constants are wrong.
                logger.warning(f"Class ID order mismatch! Expected ID->name mapping {expected}, "
                               f"got {actual} from model. CIGARETTE_CLASS_ID / FACE_CLASS_ID "
                               f"will select the wrong classes.")
            else:
                logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
                               f"Got from model: {model_names}. Please verify consistency.")
        else:
            logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")

        logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading behavior model: {e}", exc_info=True)
        # Never leave a half-initialised pair behind.
        detection_model_sahi_behavior = None
        yolo_model_direct_behavior = None
        raise RuntimeError(f"Could not load behavior model: {e}") from e

    return detection_model_sahi_behavior, yolo_model_direct_behavior
|
||
|
||
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int, target_class_name: str, model_actual_class_names: list):
    """Run one SAHI sliced-prediction pass and keep only one class's predictions.

    Args:
        image_cv: Image passed straight to ``get_sliced_prediction``.
        model: SAHI detection model.
        sahi_params: Must provide ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        target_class_id: Category ID to keep.
        target_class_name: Used only in log messages.
        model_actual_class_names: Class names indexed by ID. Its length bounds
            which category IDs are accepted.

    Returns:
        list: SAHI predictions whose category ID equals ``target_class_id``.
        Returns an empty list if SAHI raised (the error is logged).
    """
    try:
        # Tile-merging uses greedy NMM with IoS >= 0.7, per class.
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.7,
            postprocess_class_agnostic=False,
        )
    except Exception as err:
        logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {err}", exc_info=True)
        return []

    kept = []
    for prediction in getattr(sahi_result, "object_prediction_list", ()):
        category_id = int(prediction.category.id)

        # Reject IDs that the loaded model does not define.
        if not 0 <= category_id < len(model_actual_class_names):
            logger.warning(f"SAHI Raw: Detected class ID {category_id} for '{target_class_name}' "
                           f"is out of range for actual model class names ({len(model_actual_class_names)} classes). Ignoring pred.")
            continue

        if category_id != target_class_id:
            continue
        kept.append(prediction)

    logger.info(f"SAHI found {len(kept)} instances of '{target_class_name}' with params {sahi_params}.")
    return kept
|
||
|
||
|
||
# Stage-1 face qualification thresholds, applied to SAHI face predictions.
_FACE_MIN_SCORE = 0.65
_FACE_MIN_SIDE_PX = 32
_FACE_ASPECT_RATIO_RANGE = (0.6, 1.6)  # exclusive bounds on width / height

# Stage-2 cigarette plausibility thresholds, in global-image pixels.
_CIGARETTE_MAX_LONG_SIDE_PX = 100
_CIGARETTE_MAX_SHORT_SIDE_PX = 40
_CIGARETTE_ASPECT_RATIO_RANGE = (2.5, 25.0)  # exclusive bounds on long side / short side


def _detect_faces_with_sahi(image_cv, model_actual_class_names_list):
    """Stage 1a: detect faces with SAHI, retrying with small slices if large slices find none."""
    # NOTE(review): SAHI's helpers load images as RGB, and its Ultralytics
    # wrapper flips channels before inference. If image_cv is a BGR cv2 image,
    # stage 1 may see swapped channels while stage 2 (direct YOLO) does not.
    # Confirm the expected channel order.
    face_class_name = CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID]

    logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
    faces = _run_sahi_and_process_results(
        image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_LARGE, FACE_CLASS_ID, face_class_name, model_actual_class_names_list
    )
    if faces:
        # NOTE(review): the small-slice retry only happens when this pass finds
        # nothing at all, not when all of its faces later fail qualification.
        return faces

    logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
    faces = _run_sahi_and_process_results(
        image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_SMALL, FACE_CLASS_ID, face_class_name, model_actual_class_names_list
    )
    if not faces:
        logger.info("No SAHI face detections found even with small slices.")
    return faces


def _qualify_faces(face_preds):
    """Stage 1b: keep faces that pass the confidence, minimum-size and aspect-ratio checks."""
    min_ratio, max_ratio = _FACE_ASPECT_RATIO_RANGE
    qualified = []
    for face_pred in face_preds:
        # 1. Confidence check.
        if face_pred.score.value < _FACE_MIN_SCORE:
            continue

        # 2. Minimum-size check.
        x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
        width, height = x2 - x1, y2 - y1
        if width < _FACE_MIN_SIDE_PX or height < _FACE_MIN_SIDE_PX:
            continue

        # 3. Aspect-ratio check: faces are roughly square.
        aspect_ratio = width / height if height > 0 else 0
        if not (min_ratio < aspect_ratio < max_ratio):
            continue

        qualified.append(face_pred)

    logger.info(f"Face Qualification: Initial candidates: {len(face_preds)}, Qualified after filtering: {len(qualified)}")
    return qualified


def _is_plausible_cigarette(width, height):
    """Return True if a cigarette box of this size (global pixels) passes the size and shape checks."""
    long_side = max(width, height)
    short_side = min(width, height)

    # 1. Absolute size: a cigarette near a face cannot be this large.
    if long_side > _CIGARETTE_MAX_LONG_SIDE_PX or short_side > _CIGARETTE_MAX_SHORT_SIDE_PX:
        logger.info(f" REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
        return False

    if short_side > 0:
        # 2. Elongation: rejects oddly proportioned objects (e.g. poles, round blobs).
        min_ratio, max_ratio = _CIGARETTE_ASPECT_RATIO_RANGE
        aspect_ratio = long_side / short_side
        if not (min_ratio < aspect_ratio < max_ratio):
            logger.info(f" REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
            return False
        return True

    if long_side > 0:
        # Degenerate one-pixel-thin line (one side collapsed to 0 after int truncation).
        logger.info(" REJECTED Cigarette: BBox is a zero-width/height line.")
    # A zero-area box is rejected without a log line.
    return False


def _detect_cigarettes_in_face_roi(image_cv, face_pred_sahi):
    """Stage 2 for one face: run direct YOLO on the padded face crop and cross-validate the result.

    Returns the cigarette result dicts (global coordinates) only if the crop
    contains both a face and at least one plausible cigarette. Otherwise
    returns an empty list.
    """
    x1_face, y1_face, x2_face, y2_face = map(int, face_pred_sahi.bbox.to_xyxy())
    image_h, image_w = image_cv.shape[:2]

    # Pad the face box and clamp it to the image bounds.
    roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
    roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
    roi_x2 = min(image_w, x2_face + FACE_ROI_PADDING)
    roi_y2 = min(image_h, y2_face + FACE_ROI_PADDING)
    face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]

    if face_roi_crop.size == 0:
        logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
        return []

    # The same confidence threshold also gates the face used for cross-validation.
    roi_results = yolo_model_direct_behavior.predict(source=face_roi_crop, verbose=False, device=DEVICE, conf=CIGARETTE_IN_ROI_CONF_THRESHOLD)

    is_face_confirmed_in_roi = False
    cigarettes = []
    if roi_results:
        for box in roi_results[0].boxes:
            class_id = int(box.cls)
            if class_id == FACE_CLASS_ID:
                is_face_confirmed_in_roi = True
                continue
            if class_id != CIGARETTE_CLASS_ID:
                continue

            xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())
            # Shift ROI-local coordinates back into the full image.
            cig_minx_global = xc1_roi + roi_x1
            cig_miny_global = yc1_roi + roi_y1
            cig_width_global = (xc2_roi + roi_x1) - cig_minx_global
            cig_height_global = (yc2_roi + roi_y1) - cig_miny_global

            # Only cigarettes that pass every check count as valid.
            if not _is_plausible_cigarette(cig_width_global, cig_height_global):
                continue

            cigarettes.append({
                "type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
                "size": [int(cig_width_global), int(cig_height_global)],
                "leftTopPoint": [int(cig_minx_global), int(cig_miny_global)],
                "score": round(float(box.conf), 4),
            })

    found_cigarette_in_roi = bool(cigarettes)
    # Final decision: the ROI must confirm a face AND contain a qualified cigarette.
    if is_face_confirmed_in_roi and found_cigarette_in_roi:
        logger.info(f" SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
        return cigarettes

    logger.info(f" INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")
    return []


def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
    """Detect cigarettes being smoked near faces using a two-stage pipeline.

    Stage 1 finds faces over the whole image with SAHI sliced inference and
    keeps only qualified ones (confidence, size, aspect ratio). Stage 2 runs
    the direct YOLO model on each padded face crop. A cigarette is reported
    only if the same crop also re-detects a face and the cigarette passes the
    size and shape checks.

    Args:
        image_cv: Input image array, indexed as ``[row, col]`` for cropping.
        extract: Currently unused; kept for interface compatibility.

    Returns:
        On success, a list (possibly empty) of dicts with keys ``"type"``,
        ``"size"`` ``[w, h]``, ``"leftTopPoint"`` ``[x, y]`` (global pixels)
        and ``"score"``. If the models cannot be loaded or have no class
        names, a dict ``{"error": <message>}`` is returned instead.
    """
    # Lazily load the models on first use.
    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
        try:
            initialize_smoking_model()
        except RuntimeError:
            logger.error("Behavior model could not be initialized for detection.")
            return {"error": "Behavior model is not available."}

        if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
            logger.error("Behavior model components are still None after initialization attempt.")
            return {"error": "Behavior model failed to initialize."}

    model_actual_class_names_dict = yolo_model_direct_behavior.names
    if not model_actual_class_names_dict:
        logger.error("Cannot determine actual class names from the loaded YOLO model.")
        return {"error": "Model class names missing."}

    # Dense ID-indexed name list; any gaps in the model's IDs become "".
    max_id = max(model_actual_class_names_dict)
    model_actual_class_names_list = [model_actual_class_names_dict.get(i, "") for i in range(max_id + 1)]

    logger.info("Executing two-stage smoking behavior detection...")

    # --- Stage 1: detect and qualify faces with SAHI ---
    detected_faces = _detect_faces_with_sahi(image_cv, model_actual_class_names_list)
    if detected_faces:
        detected_faces = _qualify_faces(detected_faces)

    targets_output_list = []
    if not detected_faces:
        logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
        return targets_output_list

    logger.info(f"Found {len(detected_faces)} qualified faces in Stage 1. Proceeding to Stage 2...")

    # --- Stage 2: detect cigarettes inside each face ROI with direct YOLO ---
    # NOTE(review): padded ROIs of neighbouring faces can overlap, so the same
    # cigarette may be reported once per face; consider de-duplicating.
    for face_pred_sahi in detected_faces:
        targets_output_list.extend(_detect_cigarettes_in_face_roi(image_cv, face_pred_sahi))

    logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
    return targets_output_list