现有服务备份

This commit is contained in:
2025-07-24 12:45:27 +08:00
commit 7ae047c7c2
23 changed files with 161390 additions and 0 deletions

258
utils/smoking_detector.py Normal file
View File

@ -0,0 +1,258 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import torch # 用于设备检测和YOLO模型加载
from ultralytics import YOLO # 用于阶段二的ROI检测
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) # 添加这行可以看到logger的输出
# --- 更新的配置 ---
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID = 0
FACE_CLASS_ID = 1
MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"
# SAHI 切片参数
SAHI_PARAMS_FACE_LARGE = {"slice_height": 1280, "slice_width": 1280, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_FACE_SMALL = {"slice_height": 640, "slice_width": 640, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# ROI 内香烟检测的YOLO置信度
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
FACE_ROI_PADDING = 30
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 全局模型实例
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None
def initialize_smoking_model():
global detection_model_sahi_behavior, yolo_model_direct_behavior
if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
try:
# 1. SAHI封装的模型用于全局切片检测
detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
model_type="ultralytics",
model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
confidence_threshold=0.7,
device=DEVICE,
)
# 2. Ultralytics直接加载的模型用于快速的ROI检测
yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
yolo_model_direct_behavior.to(DEVICE)
if hasattr(yolo_model_direct_behavior, 'names') and yolo_model_direct_behavior.names:
model_names = list(yolo_model_direct_behavior.names.values())
if model_names != CLASS_NAMES_BEHAVIOR and sorted(model_names) != sorted(CLASS_NAMES_BEHAVIOR) :
logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
f"Got from model: {model_names}. Please verify consistency.")
else:
logger.info(f"Model class names confirmed: {model_names}")
else:
logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")
logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
except Exception as e:
logger.error(f"Error loading behavior model: {e}", exc_info=True)
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None
raise RuntimeError(f"Could not load behavior model: {e}")
return detection_model_sahi_behavior, yolo_model_direct_behavior
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int, target_class_name: str, model_actual_class_names: list):
try:
sahi_result = get_sliced_prediction(
image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0,
postprocess_type="GREEDYNMM",
postprocess_match_metric="IOS",
postprocess_match_threshold=0.7,
postprocess_class_agnostic=False
)
except Exception as e:
logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {e}", exc_info=True)
return []
valid_predictions = []
if hasattr(sahi_result, 'object_prediction_list'):
for pred in sahi_result.object_prediction_list:
cat_id_from_sahi = int(pred.category.id)
if not (0 <= cat_id_from_sahi < len(model_actual_class_names)):
logger.warning(f"SAHI Raw: Detected class ID {cat_id_from_sahi} for '{target_class_name}' "
f"is out of range for actual model class names ({len(model_actual_class_names)} classes). Ignoring pred.")
continue
if cat_id_from_sahi == target_class_id:
valid_predictions.append(pred)
logger.info(f"SAHI found {len(valid_predictions)} instances of '{target_class_name}' with params {sahi_params}.")
return valid_predictions
def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
global detection_model_sahi_behavior, yolo_model_direct_behavior
if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
try:
initialize_smoking_model()
except RuntimeError:
logger.error("Behavior model could not be initialized for detection.")
return {"error": "Behavior model is not available."}
if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
logger.error("Behavior model components are still None after initialization attempt.")
return {"error": "Behavior model failed to initialize."}
model_actual_class_names_dict = yolo_model_direct_behavior.names
if not model_actual_class_names_dict:
logger.error("Cannot determine actual class names from the loaded YOLO model.")
return {"error": "Model class names missing."}
max_id = max(model_actual_class_names_dict.keys())
model_actual_class_names_list = [""] * (max_id + 1)
for id_val, name_val in model_actual_class_names_dict.items():
model_actual_class_names_list[id_val] = name_val
logger.info(f"Executing two-stage smoking behavior detection...")
# --- 阶段一: 使用SAHI检测人脸 ---
detected_faces_sahi = []
logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
face_preds_large = _run_sahi_and_process_results(
image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_LARGE, FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], model_actual_class_names_list
)
if face_preds_large:
detected_faces_sahi = face_preds_large
else:
logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
face_preds_small = _run_sahi_and_process_results(
image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_SMALL, FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], model_actual_class_names_list
)
if face_preds_small:
detected_faces_sahi = face_preds_small
else:
logger.info("No SAHI face detections found even with small slices.")
# <<<--- 逻辑增强点 1对检测到的人脸进行严格的资格审查 --- START --->>>
if detected_faces_sahi:
original_face_count = len(detected_faces_sahi)
qualified_faces = []
for face_pred in detected_faces_sahi:
# 1. 置信度审查 (您可以调整这个阈值)
if face_pred.score.value < 0.65:
continue
# 2. 尺寸审查 (您可以调整这些像素值)
x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
width, height = x2 - x1, y2 - y1
if width < 32 or height < 32:
continue
# 3. 长宽比审查 (您可以调整这个范围)
aspect_ratio = width / height if height > 0 else 0
if not (0.6 < aspect_ratio < 1.6):
continue
qualified_faces.append(face_pred)
logger.info(f"Face Qualification: Initial candidates: {original_face_count}, Qualified after filtering: {len(qualified_faces)}")
detected_faces_sahi = qualified_faces # 用审查合格的列表覆盖原来的
# <<<--- 逻辑增强点 1 --- END --->>>
targets_output_list = []
if not detected_faces_sahi:
logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
return targets_output_list
logger.info(f"Found {len(detected_faces_sahi)} qualified faces in Stage 1. Proceeding to Stage 2...")
# --- 阶段二: 在每个人脸ROI内使用直接YOLO检测香烟 ---
for face_pred_sahi in detected_faces_sahi:
x1_face, y1_face, x2_face, y2_face = map(int, face_pred_sahi.bbox.to_xyxy())
# 定义并裁剪ROI
roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
roi_x2 = min(image_cv.shape[1], x2_face + FACE_ROI_PADDING)
roi_y2 = min(image_cv.shape[0], y2_face + FACE_ROI_PADDING)
face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]
if face_roi_crop.size == 0:
logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
continue
# <<<--- 逻辑增强点 2在ROI内进行交叉验证 --- START --->>>
roi_results = yolo_model_direct_behavior.predict(source=face_roi_crop, verbose=False, device=DEVICE, conf=CIGARETTE_IN_ROI_CONF_THRESHOLD)
is_face_confirmed_in_roi = False
found_cigarette_in_roi = False
cigarette_details_for_this_face = []
if roi_results and len(roi_results) > 0:
result_for_roi = roi_results[0]
for box in result_for_roi.boxes:
class_id = int(box.cls)
if class_id == FACE_CLASS_ID:
is_face_confirmed_in_roi = True
elif class_id == CIGARETTE_CLASS_ID:
xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())
cig_minx_global = xc1_roi + roi_x1
cig_miny_global = yc1_roi + roi_y1
cig_maxx_global = xc2_roi + roi_x1
cig_maxy_global = yc2_roi + roi_y1
cig_width_global = cig_maxx_global - cig_minx_global
cig_height_global = cig_maxy_global - cig_miny_global
# <<<--- 逻辑增强点 3对香烟进行严格的尺寸和比例审查 --- START --->>>
long_side = max(cig_width_global, cig_height_global)
short_side = min(cig_width_global, cig_height_global)
# 1. 绝对尺寸审查 (根据您的要求)
if long_side > 100 or short_side > 40:
logger.info(f" REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
continue
# 2. 长宽比审查 (针对电线杆等误检)
if short_side > 0:
aspect_ratio = long_side / short_side
# 可调:一个正常的香烟,其长度应远大于宽度。这里设定一个范围来排除比例怪异的物体。
if not (2.5 < aspect_ratio < 25.0):
logger.info(f" REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
continue
elif long_side > 0: # 处理 short_side为0 但 long_side不为0 的细线情况
logger.info(f" REJECTED Cigarette: BBox is a zero-width/height line.")
continue
else: # 宽和高都为0跳过
continue
# <<<--- 逻辑增强点 3 --- END --->>>
# 只有通过了所有审查的香烟才被视为有效
found_cigarette_in_roi = True
cigarette_details_for_this_face.append({
"type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
"size": [int(cig_width_global), int(cig_height_global)],
"leftTopPoint": [int(cig_minx_global), int(cig_miny_global)],
"score": round(float(box.conf), 4)
})
# 最终裁定必须在ROI内同时确认有人脸和合格的香烟
if is_face_confirmed_in_roi and found_cigarette_in_roi:
logger.info(f" SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
targets_output_list.extend(cigarette_details_for_this_face)
else:
logger.info(f" INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")
# <<<--- 逻辑增强点 2 --- END --->>>
logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
return targets_output_list