现有服务备份

This commit is contained in:
2025-07-24 12:45:27 +08:00
commit 7ae047c7c2
23 changed files with 161390 additions and 0 deletions

View File

@ -0,0 +1,93 @@
import cv2
import numpy as np
# import supervision as sv # No longer directly needed
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output
import logging
logger = logging.getLogger(__name__)
CLASS_NAMES_FIRE_SMOKE = ["fire", "smoggy"]
MODEL_WEIGHTS_PATH_FIRE_SMOKE = "models/fire_smoke_model/best.pt"
SAHI_PARAMS_LARGE = {"slice_height": 2000, "slice_width": 2000, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_SMALL = {"slice_height": 256, "slice_width": 256, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
detection_model_fire_smoke = None
def initialize_fire_smoke_model():
global detection_model_fire_smoke
if detection_model_fire_smoke is None:
logger.info(f"Loading fire_smoke detection model from {MODEL_WEIGHTS_PATH_FIRE_SMOKE}...")
try:
detection_model_fire_smoke = AutoDetectionModel.from_pretrained(
model_type="ultralytics",
model_path=MODEL_WEIGHTS_PATH_FIRE_SMOKE,
confidence_threshold=0.8,
)
logger.info("Fire_smoke detection model loaded successfully.")
except Exception as e:
logger.error(f"Error loading fire_smoke model: {e}", exc_info=True)
raise RuntimeError(f"Could not load fire_smoke model: {e}")
return detection_model_fire_smoke
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
# This helper function is the same
try:
sahi_result = get_sliced_prediction(
image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0
)
except Exception as e:
logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
return []
object_predictions_original = sahi_result.object_prediction_list
valid_original_predictions = []
for pred in object_predictions_original:
cat_id = int(pred.category.id)
if 0 <= cat_id < len(model_class_names):
valid_original_predictions.append(pred)
else:
logger.warning(f"SAHI Raw: Detected class ID {cat_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
return valid_original_predictions
def detect_fire_smoke_with_sahi(image_cv: np.ndarray, extract: bool = False):
global detection_model_fire_smoke
if detection_model_fire_smoke is None:
logger.warning("Fire_smoke model was not loaded. Attempting to initialize now.")
initialize_fire_smoke_model()
if detection_model_fire_smoke is None:
logger.error("Fire_smoke model could not be initialized for detection.")
return {"error": "Fire_smoke model is not available."}
logger.info(f"Executing fire_smoke detection, extract flag is: {extract}")
_obj_preds_all_model_outputs = []
logger.info(f"Attempting fire_smoke detection with SAHI (Large Slices)...")
obj_preds_attempt1 = _run_sahi_and_process_results(
image_cv, detection_model_fire_smoke, SAHI_PARAMS_LARGE, CLASS_NAMES_FIRE_SMOKE
)
if obj_preds_attempt1:
_obj_preds_all_model_outputs = obj_preds_attempt1
else:
logger.info("No SAHI fire_smoke detections (large slices). Skipping small slices attempt as requested.")
_obj_preds_all_model_outputs = []
targets_output_list = []
for pred in _obj_preds_all_model_outputs:
class_id = int(pred.category.id)
class_name = CLASS_NAMES_FIRE_SMOKE[class_id]
minx, miny, maxx, maxy = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy
bbox_width = maxx - minx
bbox_height = maxy - miny
targets_output_list.append({
"type": class_name,
"size": [int(bbox_width), int(bbox_height)],
"leftTopPoint": [int(minx), int(miny)],
"score": round(float(pred.score.value), 4)
})
logger.info(f"Prepared {len(targets_output_list)} fire/smoke targets for API response.")
return targets_output_list

View File

@ -0,0 +1,111 @@
import cv2
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output if base64 image isn't returned
import logging
logger = logging.getLogger(__name__)
CLASS_NAMES_HELMET_MODEL_OUTPUT = ["helmet", "nohelmet", "vast", "novast"]
TARGET_CLASS_FOR_REPORTING = "nohelmet"
MODEL_WEIGHTS_PATH_HELMET = "models/helmet_model/best.pt"
SAHI_PARAMS_LARGE = {"slice_height": 1200, "slice_width": 1200, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_SMALL = {"slice_height": 700, "slice_width": 700, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
detection_model_helmet = None
def initialize_helmet_model():
global detection_model_helmet
if detection_model_helmet is None:
logger.info(f"Loading helmet detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
try:
detection_model_helmet = AutoDetectionModel.from_pretrained(
model_type="ultralytics",
model_path=MODEL_WEIGHTS_PATH_HELMET,
confidence_threshold=0.8,
)
logger.info("Helmet detection model loaded successfully.")
except Exception as e:
logger.error(f"Error loading helmet model: {e}", exc_info=True)
raise RuntimeError(f"Could not load helmet model: {e}")
return detection_model_helmet
# encode_image_to_base64 is no longer strictly needed if not returning annotated image
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
# This helper function is mostly the same, returns List[ObjectPrediction]
try:
sahi_result = get_sliced_prediction(
image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0
)
except Exception as e:
logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
return [] # Return empty list of predictions on SAHI error
object_predictions_original = sahi_result.object_prediction_list
valid_original_predictions = []
for pred in object_predictions_original:
cat_id = int(pred.category.id)
if 0 <= cat_id < len(model_class_names):
valid_original_predictions.append(pred)
else:
logger.warning(f"SAHI Raw: Detected class ID {cat_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
return valid_original_predictions
def detect_hardhat_with_sahi(image_cv: np.ndarray, extract: bool = False):
global detection_model_helmet
if detection_model_helmet is None:
logger.warning("Helmet model was not loaded. Attempting to initialize now.")
initialize_helmet_model()
if detection_model_helmet is None:
logger.error("Helmet model could not be initialized for detection.")
return {"error": "Helmet model is not available."} # Return error dict
logger.info(f"Executing helmet task (reporting only '{TARGET_CLASS_FOR_REPORTING}'), extract flag is: {extract}")
_obj_preds_all_model_outputs = [] # Store all valid SAHI ObjectPrediction
logger.info(f"Attempting detection with SAHI (Large Slices)...")
obj_preds_attempt1 = _run_sahi_and_process_results(
image_cv, detection_model_helmet, SAHI_PARAMS_LARGE, CLASS_NAMES_HELMET_MODEL_OUTPUT
)
if obj_preds_attempt1: # If list is not empty
logger.info(f"Initial SAHI detections (large slices): {len(obj_preds_attempt1)}")
_obj_preds_all_model_outputs = obj_preds_attempt1
else:
logger.info("No SAHI detections (large slices). Retrying with small slices...")
obj_preds_attempt2 = _run_sahi_and_process_results(
image_cv, detection_model_helmet, SAHI_PARAMS_SMALL, CLASS_NAMES_HELMET_MODEL_OUTPUT
)
if obj_preds_attempt2:
logger.info(f"Initial SAHI detections (small slices): {len(obj_preds_attempt2)}")
_obj_preds_all_model_outputs = obj_preds_attempt2
else:
logger.info("No initial SAHI detections found even with small slices.")
# --- Transform to new output format, filtering for TARGET_CLASS_FOR_REPORTING ---
targets_output_list = []
for pred in _obj_preds_all_model_outputs:
class_id = int(pred.category.id)
class_name = CLASS_NAMES_HELMET_MODEL_OUTPUT[class_id]
if class_name == TARGET_CLASS_FOR_REPORTING:
minx, miny, maxx, maxy = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy
bbox_width = maxx - minx
bbox_height = maxy - miny
targets_output_list.append({
"type": class_name,
"size": [int(bbox_width), int(bbox_height)],
"leftTopPoint": [int(minx), int(miny)],
"score": round(float(pred.score.value), 4) # Rounded score
})
logger.info(f"Prepared {len(targets_output_list)} '{TARGET_CLASS_FOR_REPORTING}' targets for API response.")
return targets_output_list # Directly return list of target dicts

199
utils/hardhat_detector.py Normal file
View File

@ -0,0 +1,199 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List # 导入Set和List以支持类型提示
# 使用在main.py中配置好的根logger
logger = logging.getLogger(__name__)
# --- 1. 配置模型和类别信息 (这是需要您根据实际模型调整的部分) ---
# 新的“状态检测”模型的类别名称列表。顺序必须与模型训练时一致。
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]
# 后端期望接收的违规类型字符串
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"
# 新模型的权重文件路径
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"
# --- 2. 保留并适配现有的调用接口 ---
# SAHI切片参数
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}
# 全局模型实例
detection_model_helmet = None # 变量名保持'detection_model_helmet'与旧版一致
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def initialize_helmet_model():
"""
符合 main.py 调用约定的初始化函数。
它加载新的 PPE 状态模型。
"""
global detection_model_helmet
if detection_model_helmet is None:
if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
logger.error(error_msg)
raise FileNotFoundError(error_msg)
logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
try:
detection_model_helmet = AutoDetectionModel.from_pretrained(
model_type="yolov8",
model_path=MODEL_WEIGHTS_PATH_HELMET,
confidence_threshold=0.83,
device=DEVICE,
)
# 可选的验证步骤
loaded_class_names = list(detection_model_helmet.model.names.values())
if sorted(loaded_class_names) != sorted(CLASS_NAMES_PPE_STATE_MODEL):
logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
f"but model has {loaded_class_names}. Please check consistency.")
logger.info("hardhat state detection model loaded successfully.")
except Exception as e:
logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
raise RuntimeError(f"Could not load hardhat state model: {e}")
return detection_model_helmet
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
"""
辅助函数执行SAHI推理并返回所有有效预测的列表。
"""
try:
sahi_result = get_sliced_prediction(
image=image_cv,
detection_model=model,
slice_height=sahi_params["slice_height"],
slice_width=sahi_params["slice_width"],
overlap_height_ratio=sahi_params["overlap_height_ratio"],
overlap_width_ratio=sahi_params["overlap_width_ratio"],
verbose=0,
postprocess_type="GREEDYNMM",
postprocess_match_metric="IOS",
postprocess_match_threshold=0.5
)
except Exception as e:
logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
return []
object_predictions = sahi_result.object_prediction_list if sahi_result else []
valid_predictions = []
for pred in object_predictions:
cat_id = int(pred.category.id)
if 0 <= cat_id < len(model_class_names):
valid_predictions.append(pred)
else:
logger.warning(f"SAHI Raw: Detected class ID {cat_id} is out of range. Ignoring prediction.")
return valid_predictions
def map_model_output_to_backend_violations(class_name: str) -> list[str]:
"""
将模型输出的状态类别名“翻译”成后端需要的违规代码列表。
"""
if class_name == "nohelmet":
return [BACKEND_VIOLATION_CODE_NO_HELMET]
elif class_name == "novest":
return [BACKEND_VIOLATION_CODE_NO_VEST]
elif class_name == "noequipment":
return [BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST]
return []
def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
"""
主检测函数现在接收一个`requested_violations`集合参数,用于过滤最终的输出。
"""
global detection_model_helmet
if detection_model_helmet is None:
logger.warning("hardhat state model was not loaded. Initializing on first call.")
try:
initialize_helmet_model()
except Exception as e:
logger.error(f"Failed to initialize model on demand: {e}")
raise RuntimeError(f"hardhat model could not be initialized: {e}")
logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")
img_height, img_width, _ = image_cv.shape
raw_sahi_predictions = _run_sahi_and_process_results(
image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
)
if not raw_sahi_predictions:
logger.info("No raw objects detected by SAHI.")
return []
logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")
# --- 后处理过滤逻辑 ---
filtered_predictions = []
MIN_PERSON_HEIGHT_PX = 40
MAX_PERSON_WIDTH_RATIO = 0.21
MAX_PERSON_HEIGHT_RATIO = 0.33
MIN_ASPECT_RATIO = 1.2
MAX_ASPECT_RATIO = 5.0
FEET_POSITION_THRESHOLD_RATIO = 0.05
for pred in raw_sahi_predictions:
minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
score = pred.score.value
bbox_width = maxx - minx
bbox_height = maxy - miny
if bbox_height < MIN_PERSON_HEIGHT_PX: continue
relative_width = bbox_width / img_width
relative_height = bbox_height / img_height
if relative_width > MAX_PERSON_WIDTH_RATIO or relative_height > MAX_PERSON_HEIGHT_RATIO: continue
if bbox_width == 0: continue
aspect_ratio = bbox_height / bbox_width
if not (MIN_ASPECT_RATIO <= aspect_ratio <= MAX_ASPECT_RATIO): continue
if maxy < (img_height * FEET_POSITION_THRESHOLD_RATIO): continue
bbox_area = bbox_width * bbox_height
image_area = img_width * img_height
if bbox_area > (image_area * 0.1) and score < 0.5: continue
filtered_predictions.append(pred)
logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")
# --- 映射和最终过滤逻辑 ---
targets_output_list = []
for pred in filtered_predictions:
class_id = int(pred.category.id)
class_name = CLASS_NAMES_PPE_STATE_MODEL[class_id]
score = round(float(pred.score.value), 4)
minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")
all_possible_violations = map_model_output_to_backend_violations(class_name)
final_violations_to_report = [
v for v in all_possible_violations if v in requested_violations
]
if not final_violations_to_report:
continue
bbox_width = maxx - minx
bbox_height = maxy - miny
for code in final_violations_to_report:
logger.info(f" └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
targets_output_list.append({
"type": code,
"size": [int(bbox_width), int(bbox_height)],
"leftTopPoint": [int(minx), int(miny)],
"score": score
})
logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
return targets_output_list

163
utils/pho_detector.py Normal file
View File

@ -0,0 +1,163 @@
# utils/pho_detector.py
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
from typing import List, Dict, Any
# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Module-Specific Configuration ---
# All classes this module can detect.
ALL_SUPPORTED_CLASSES = ['pho', 'shelves', 'pile', 'hole']
# Configuration for the first model (detects 'pho' and 'shelves').
MODEL_1_CONFIG = {
"name": "pho_model",
# IMPORTANT: Update this path to your actual model location.
"weights_path": "models/photovoltaic_model/pho_model/best.pt",
# Class order must match the model's training IDs (0: 'pho', 1: 'shelves').
"class_names": ['pho', 'shelves'],
"confidence_threshold": 0.7
}
# Configuration for the second model (detects 'pile' and 'hole').
MODEL_2_CONFIG = {
"name": "pile_model",
# IMPORTANT: Update this path to your actual model location.
"weights_path": "models/photovoltaic_model/pile_model/best.pt",
# Class order must match the model's training IDs (0: 'pile', 1: 'hole').
"class_names": ['pile', 'hole'],
"confidence_threshold": 0.7
}
# Unified SAHI slicing parameters for all tasks in this module.
SAHI_PARAMS = {
"slice_height": 800,
"slice_width": 800,
"overlap_height_ratio": 0.2,
"overlap_width_ratio": 0.2
}
# Global dictionary to cache loaded model instances.
detection_models: Dict[str, AutoDetectionModel] = {}
# --- Model Management Functions ---
def initialize_model(model_config: Dict[str, Any]) -> AutoDetectionModel:
"""Initializes a single model based on its configuration if not already loaded."""
global detection_models
model_name = model_config["name"]
if model_name not in detection_models:
logger.info(f"Loading detection model '{model_name}' from {model_config['weights_path']}...")
try:
model = AutoDetectionModel.from_pretrained(
model_type="ultralytics",
model_path=model_config["weights_path"],
confidence_threshold=model_config["confidence_threshold"],
)
detection_models[model_name] = model
logger.info(f"Model '{model_name}' loaded successfully.")
except Exception as e:
logger.error(f"Error loading model '{model_name}': {e}", exc_info=True)
raise RuntimeError(f"Could not load model '{model_name}': {e}")
return detection_models[model_name]
def initialize_all_pho_models():
"""Initializes all models defined in this module. Called once at application startup."""
logger.info("Pre-initializing all photovoltaic-related models...")
initialize_model(MODEL_1_CONFIG)
initialize_model(MODEL_2_CONFIG)
logger.info("All photovoltaic models have been initialized.")
# --- Inference and Core Logic ---
def _run_inference_for_model(image_cv: np.ndarray, model_config: Dict[str, Any]) -> List[Dict]:
"""Internal function to run SAHI prediction for a single specified model."""
try:
model = initialize_model(model_config)
except RuntimeError as e:
logger.error(f"Cannot run inference; model '{model_config['name']}' failed to initialize: {e}")
return []
logger.info(f"Running SAHI prediction with model: {model_config['name']}...")
try:
sahi_result = get_sliced_prediction(
image=image_cv,
detection_model=model,
**SAHI_PARAMS,
verbose=0
)
except Exception as e:
logger.error(f"SAHI prediction failed for model '{model_config['name']}': {e}", exc_info=True)
return []
targets_output_list = []
# Use the specific class list for this model to map IDs to names.
model_class_names = model_config["class_names"]
for pred in sahi_result.object_prediction_list:
cat_id = int(pred.category.id)
# Core Logic: Map the model's local ID to its correct global class name.
if 0 <= cat_id < len(model_class_names):
class_name = model_class_names[cat_id]
minx, miny, maxx, maxy = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy
targets_output_list.append({
"type": class_name,
"size": [int(maxx - minx), int(maxy - miny)],
"leftTopPoint": [int(minx), int(miny)],
"score": round(float(pred.score.value), 4)
})
else:
logger.warning(f"Model '{model_config['name']}' detected an out-of-range class ID: {cat_id}. Ignoring.")
logger.info(f"Model '{model_config['name']}' found {len(targets_output_list)} raw objects.")
return targets_output_list
def detect_objects_by_class(image_cv: np.ndarray, target_classes: List[str]) -> List[Dict]:
"""
Main public function. Detects objects by intelligently dispatching to the correct model(s).
Args:
image_cv: The input OpenCV image.
target_classes: A list of class names to detect (e.g., ['pho', 'hole']).
Returns:
A list of detection dictionaries for the successfully found targets.
"""
final_results = []
# Determine which targets belong to Model 1.
model_1_targets = [cls for cls in target_classes if cls in MODEL_1_CONFIG["class_names"]]
# Determine which targets belong to Model 2.
model_2_targets = [cls for cls in target_classes if cls in MODEL_2_CONFIG["class_names"]]
# Run Model 1 only if it's needed.
if model_1_targets:
logger.info(f"Model 1 ('{MODEL_1_CONFIG['name']}') triggered for targets: {model_1_targets}")
model_1_results = _run_inference_for_model(image_cv, MODEL_1_CONFIG)
# Filter the raw results to only include what the user asked for.
final_results.extend([res for res in model_1_results if res["type"] in model_1_targets])
# Run Model 2 only if it's needed.
if model_2_targets:
logger.info(f"Model 2 ('{MODEL_2_CONFIG['name']}') triggered for targets: {model_2_targets}")
model_2_results = _run_inference_for_model(image_cv, MODEL_2_CONFIG)
# Filter the raw results to only include what the user asked for.
final_results.extend([res for res in model_2_results if res["type"] in model_2_targets])
logger.info(f"Photovoltaic detection complete. Returning {len(final_results)} filtered targets.")
return final_results
# --- Export for main.py ---
# This list is imported by main.py to know which tasks to delegate to this module.
PHO_DETECTOR_SUPPORTED_CLASSES = ALL_SUPPORTED_CLASSES

258
utils/smoking_detector.py Normal file
View File

@ -0,0 +1,258 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import torch # 用于设备检测和YOLO模型加载
from ultralytics import YOLO # 用于阶段二的ROI检测
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) # 添加这行可以看到logger的输出
# --- 更新的配置 ---
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID = 0
FACE_CLASS_ID = 1
MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"
# SAHI 切片参数
SAHI_PARAMS_FACE_LARGE = {"slice_height": 1280, "slice_width": 1280, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_FACE_SMALL = {"slice_height": 640, "slice_width": 640, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
# ROI 内香烟检测的YOLO置信度
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
FACE_ROI_PADDING = 30
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 全局模型实例
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None
def initialize_smoking_model():
global detection_model_sahi_behavior, yolo_model_direct_behavior
if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
try:
# 1. SAHI封装的模型用于全局切片检测
detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
model_type="ultralytics",
model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
confidence_threshold=0.7,
device=DEVICE,
)
# 2. Ultralytics直接加载的模型用于快速的ROI检测
yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
yolo_model_direct_behavior.to(DEVICE)
if hasattr(yolo_model_direct_behavior, 'names') and yolo_model_direct_behavior.names:
model_names = list(yolo_model_direct_behavior.names.values())
if model_names != CLASS_NAMES_BEHAVIOR and sorted(model_names) != sorted(CLASS_NAMES_BEHAVIOR) :
logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
f"Got from model: {model_names}. Please verify consistency.")
else:
logger.info(f"Model class names confirmed: {model_names}")
else:
logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")
logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
except Exception as e:
logger.error(f"Error loading behavior model: {e}", exc_info=True)
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None
raise RuntimeError(f"Could not load behavior model: {e}")
return detection_model_sahi_behavior, yolo_model_direct_behavior
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int, target_class_name: str, model_actual_class_names: list):
try:
sahi_result = get_sliced_prediction(
image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0,
postprocess_type="GREEDYNMM",
postprocess_match_metric="IOS",
postprocess_match_threshold=0.7,
postprocess_class_agnostic=False
)
except Exception as e:
logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {e}", exc_info=True)
return []
valid_predictions = []
if hasattr(sahi_result, 'object_prediction_list'):
for pred in sahi_result.object_prediction_list:
cat_id_from_sahi = int(pred.category.id)
if not (0 <= cat_id_from_sahi < len(model_actual_class_names)):
logger.warning(f"SAHI Raw: Detected class ID {cat_id_from_sahi} for '{target_class_name}' "
f"is out of range for actual model class names ({len(model_actual_class_names)} classes). Ignoring pred.")
continue
if cat_id_from_sahi == target_class_id:
valid_predictions.append(pred)
logger.info(f"SAHI found {len(valid_predictions)} instances of '{target_class_name}' with params {sahi_params}.")
return valid_predictions
def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
global detection_model_sahi_behavior, yolo_model_direct_behavior
if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
try:
initialize_smoking_model()
except RuntimeError:
logger.error("Behavior model could not be initialized for detection.")
return {"error": "Behavior model is not available."}
if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
logger.error("Behavior model components are still None after initialization attempt.")
return {"error": "Behavior model failed to initialize."}
model_actual_class_names_dict = yolo_model_direct_behavior.names
if not model_actual_class_names_dict:
logger.error("Cannot determine actual class names from the loaded YOLO model.")
return {"error": "Model class names missing."}
max_id = max(model_actual_class_names_dict.keys())
model_actual_class_names_list = [""] * (max_id + 1)
for id_val, name_val in model_actual_class_names_dict.items():
model_actual_class_names_list[id_val] = name_val
logger.info(f"Executing two-stage smoking behavior detection...")
# --- 阶段一: 使用SAHI检测人脸 ---
detected_faces_sahi = []
logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
face_preds_large = _run_sahi_and_process_results(
image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_LARGE, FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], model_actual_class_names_list
)
if face_preds_large:
detected_faces_sahi = face_preds_large
else:
logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
face_preds_small = _run_sahi_and_process_results(
image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_SMALL, FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], model_actual_class_names_list
)
if face_preds_small:
detected_faces_sahi = face_preds_small
else:
logger.info("No SAHI face detections found even with small slices.")
# <<<--- 逻辑增强点 1对检测到的人脸进行严格的资格审查 --- START --->>>
if detected_faces_sahi:
original_face_count = len(detected_faces_sahi)
qualified_faces = []
for face_pred in detected_faces_sahi:
# 1. 置信度审查 (您可以调整这个阈值)
if face_pred.score.value < 0.65:
continue
# 2. 尺寸审查 (您可以调整这些像素值)
x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
width, height = x2 - x1, y2 - y1
if width < 32 or height < 32:
continue
# 3. 长宽比审查 (您可以调整这个范围)
aspect_ratio = width / height if height > 0 else 0
if not (0.6 < aspect_ratio < 1.6):
continue
qualified_faces.append(face_pred)
logger.info(f"Face Qualification: Initial candidates: {original_face_count}, Qualified after filtering: {len(qualified_faces)}")
detected_faces_sahi = qualified_faces # 用审查合格的列表覆盖原来的
# <<<--- 逻辑增强点 1 --- END --->>>
targets_output_list = []
if not detected_faces_sahi:
logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
return targets_output_list
logger.info(f"Found {len(detected_faces_sahi)} qualified faces in Stage 1. Proceeding to Stage 2...")
# --- 阶段二: 在每个人脸ROI内使用直接YOLO检测香烟 ---
for face_pred_sahi in detected_faces_sahi:
x1_face, y1_face, x2_face, y2_face = map(int, face_pred_sahi.bbox.to_xyxy())
# 定义并裁剪ROI
roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
roi_x2 = min(image_cv.shape[1], x2_face + FACE_ROI_PADDING)
roi_y2 = min(image_cv.shape[0], y2_face + FACE_ROI_PADDING)
face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]
if face_roi_crop.size == 0:
logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
continue
# <<<--- 逻辑增强点 2在ROI内进行交叉验证 --- START --->>>
roi_results = yolo_model_direct_behavior.predict(source=face_roi_crop, verbose=False, device=DEVICE, conf=CIGARETTE_IN_ROI_CONF_THRESHOLD)
is_face_confirmed_in_roi = False
found_cigarette_in_roi = False
cigarette_details_for_this_face = []
if roi_results and len(roi_results) > 0:
result_for_roi = roi_results[0]
for box in result_for_roi.boxes:
class_id = int(box.cls)
if class_id == FACE_CLASS_ID:
is_face_confirmed_in_roi = True
elif class_id == CIGARETTE_CLASS_ID:
xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())
cig_minx_global = xc1_roi + roi_x1
cig_miny_global = yc1_roi + roi_y1
cig_maxx_global = xc2_roi + roi_x1
cig_maxy_global = yc2_roi + roi_y1
cig_width_global = cig_maxx_global - cig_minx_global
cig_height_global = cig_maxy_global - cig_miny_global
# <<<--- 逻辑增强点 3对香烟进行严格的尺寸和比例审查 --- START --->>>
long_side = max(cig_width_global, cig_height_global)
short_side = min(cig_width_global, cig_height_global)
# 1. 绝对尺寸审查 (根据您的要求)
if long_side > 100 or short_side > 40:
logger.info(f" REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
continue
# 2. 长宽比审查 (针对电线杆等误检)
if short_side > 0:
aspect_ratio = long_side / short_side
# 可调:一个正常的香烟,其长度应远大于宽度。这里设定一个范围来排除比例怪异的物体。
if not (2.5 < aspect_ratio < 25.0):
logger.info(f" REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
continue
elif long_side > 0: # 处理 short_side为0 但 long_side不为0 的细线情况
logger.info(f" REJECTED Cigarette: BBox is a zero-width/height line.")
continue
else: # 宽和高都为0跳过
continue
# <<<--- 逻辑增强点 3 --- END --->>>
# 只有通过了所有审查的香烟才被视为有效
found_cigarette_in_roi = True
cigarette_details_for_this_face.append({
"type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
"size": [int(cig_width_global), int(cig_height_global)],
"leftTopPoint": [int(cig_minx_global), int(cig_miny_global)],
"score": round(float(box.conf), 4)
})
# 最终裁定必须在ROI内同时确认有人脸和合格的香烟
if is_face_confirmed_in_roi and found_cigarette_in_roi:
logger.info(f" SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
targets_output_list.extend(cigarette_details_for_this_face)
else:
logger.info(f" INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")
# <<<--- 逻辑增强点 2 --- END --->>>
logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
return targets_output_list