Backup of existing services
utils/fire_smoke_detector.py (Normal file, 93 lines added)
@@ -0,0 +1,93 @@
import cv2
import numpy as np
# import supervision as sv # No longer directly needed
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output
import logging

logger = logging.getLogger(__name__)

CLASS_NAMES_FIRE_SMOKE = ["fire", "smoggy"]
MODEL_WEIGHTS_PATH_FIRE_SMOKE = "models/fire_smoke_model/best.pt"

SAHI_PARAMS_LARGE = {"slice_height": 2000, "slice_width": 2000, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_SMALL = {"slice_height": 256, "slice_width": 256, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}

detection_model_fire_smoke = None


def initialize_fire_smoke_model():
    global detection_model_fire_smoke
    if detection_model_fire_smoke is None:
        logger.info(f"Loading fire_smoke detection model from {MODEL_WEIGHTS_PATH_FIRE_SMOKE}...")
        try:
            detection_model_fire_smoke = AutoDetectionModel.from_pretrained(
                model_type="ultralytics",
                model_path=MODEL_WEIGHTS_PATH_FIRE_SMOKE,
                confidence_threshold=0.8,
            )
            logger.info("Fire_smoke detection model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading fire_smoke model: {e}", exc_info=True)
            raise RuntimeError(f"Could not load fire_smoke model: {e}")
    return detection_model_fire_smoke


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    # This helper function is the same
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []
    object_predictions_original = sahi_result.object_prediction_list
    valid_original_predictions = []
    for pred in object_predictions_original:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < len(model_class_names):
            valid_original_predictions.append(pred)
        else:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
    return valid_original_predictions


def detect_fire_smoke_with_sahi(image_cv: np.ndarray, extract: bool = False):
    global detection_model_fire_smoke
    if detection_model_fire_smoke is None:
        logger.warning("Fire_smoke model was not loaded. Attempting to initialize now.")
        initialize_fire_smoke_model()
        if detection_model_fire_smoke is None:
            logger.error("Fire_smoke model could not be initialized for detection.")
            return {"error": "Fire_smoke model is not available."}

    logger.info(f"Executing fire_smoke detection, extract flag is: {extract}")
    _obj_preds_all_model_outputs = []

    logger.info("Attempting fire_smoke detection with SAHI (Large Slices)...")
    obj_preds_attempt1 = _run_sahi_and_process_results(
        image_cv, detection_model_fire_smoke, SAHI_PARAMS_LARGE, CLASS_NAMES_FIRE_SMOKE
    )
    if obj_preds_attempt1:
        _obj_preds_all_model_outputs = obj_preds_attempt1
    else:
        logger.info("No SAHI fire_smoke detections (large slices). Skipping small slices attempt as requested.")
        _obj_preds_all_model_outputs = []

    targets_output_list = []
    for pred in _obj_preds_all_model_outputs:
        class_id = int(pred.category.id)
        class_name = CLASS_NAMES_FIRE_SMOKE[class_id]
        minx, miny, maxx, maxy = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy
        bbox_width = maxx - minx
        bbox_height = maxy - miny
        targets_output_list.append({
            "type": class_name,
            "size": [int(bbox_width), int(bbox_height)],
            "leftTopPoint": [int(minx), int(miny)],
            "score": round(float(pred.score.value), 4)
        })

    logger.info(f"Prepared {len(targets_output_list)} fire/smoke targets for API response.")
    return targets_output_list
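A minimal calling sketch for this module, assuming the weights exist at MODEL_WEIGHTS_PATH_FIRE_SMOKE and a local test image named "example.jpg"; the image path and variable names are illustrative only.

import cv2
from utils.fire_smoke_detector import initialize_fire_smoke_model, detect_fire_smoke_with_sahi

initialize_fire_smoke_model()                 # load the weights once at startup
image = cv2.imread("example.jpg")             # assumed test image path
targets = detect_fire_smoke_with_sahi(image)  # list of {"type", "size", "leftTopPoint", "score"} dicts
for t in targets:
    print(t["type"], t["leftTopPoint"], t["size"], t["score"])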
utils/hardhat_detector copy.py (Normal file, 111 lines added)
@@ -0,0 +1,111 @@
import cv2
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output if base64 image isn't returned
import logging

logger = logging.getLogger(__name__)

CLASS_NAMES_HELMET_MODEL_OUTPUT = ["helmet", "nohelmet", "vast", "novast"]
TARGET_CLASS_FOR_REPORTING = "nohelmet"
MODEL_WEIGHTS_PATH_HELMET = "models/helmet_model/best.pt"

SAHI_PARAMS_LARGE = {"slice_height": 1200, "slice_width": 1200, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_SMALL = {"slice_height": 700, "slice_width": 700, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}

detection_model_helmet = None


def initialize_helmet_model():
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.info(f"Loading helmet detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
        try:
            detection_model_helmet = AutoDetectionModel.from_pretrained(
                model_type="ultralytics",
                model_path=MODEL_WEIGHTS_PATH_HELMET,
                confidence_threshold=0.8,
            )
            logger.info("Helmet detection model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading helmet model: {e}", exc_info=True)
            raise RuntimeError(f"Could not load helmet model: {e}")
    return detection_model_helmet


# encode_image_to_base64 is no longer strictly needed if not returning annotated image


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    # This helper function is mostly the same, returns List[ObjectPrediction]
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []  # Return empty list of predictions on SAHI error

    object_predictions_original = sahi_result.object_prediction_list
    valid_original_predictions = []
    for pred in object_predictions_original:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < len(model_class_names):
            valid_original_predictions.append(pred)
        else:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
    return valid_original_predictions


def detect_hardhat_with_sahi(image_cv: np.ndarray, extract: bool = False):
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("Helmet model was not loaded. Attempting to initialize now.")
        initialize_helmet_model()
        if detection_model_helmet is None:
            logger.error("Helmet model could not be initialized for detection.")
            return {"error": "Helmet model is not available."}  # Return error dict

    logger.info(f"Executing helmet task (reporting only '{TARGET_CLASS_FOR_REPORTING}'), extract flag is: {extract}")

    _obj_preds_all_model_outputs = []  # Store all valid SAHI ObjectPrediction

    logger.info("Attempting detection with SAHI (Large Slices)...")
    obj_preds_attempt1 = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS_LARGE, CLASS_NAMES_HELMET_MODEL_OUTPUT
    )

    if obj_preds_attempt1:  # If list is not empty
        logger.info(f"Initial SAHI detections (large slices): {len(obj_preds_attempt1)}")
        _obj_preds_all_model_outputs = obj_preds_attempt1
    else:
        logger.info("No SAHI detections (large slices). Retrying with small slices...")
        obj_preds_attempt2 = _run_sahi_and_process_results(
            image_cv, detection_model_helmet, SAHI_PARAMS_SMALL, CLASS_NAMES_HELMET_MODEL_OUTPUT
        )
        if obj_preds_attempt2:
            logger.info(f"Initial SAHI detections (small slices): {len(obj_preds_attempt2)}")
            _obj_preds_all_model_outputs = obj_preds_attempt2
        else:
            logger.info("No initial SAHI detections found even with small slices.")

    # --- Transform to new output format, filtering for TARGET_CLASS_FOR_REPORTING ---
    targets_output_list = []
    for pred in _obj_preds_all_model_outputs:
        class_id = int(pred.category.id)
        class_name = CLASS_NAMES_HELMET_MODEL_OUTPUT[class_id]

        if class_name == TARGET_CLASS_FOR_REPORTING:
            minx, miny, maxx, maxy = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy
            bbox_width = maxx - minx
            bbox_height = maxy - miny
            targets_output_list.append({
                "type": class_name,
                "size": [int(bbox_width), int(bbox_height)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": round(float(pred.score.value), 4)  # Rounded score
            })

    logger.info(f"Prepared {len(targets_output_list)} '{TARGET_CLASS_FOR_REPORTING}' targets for API response.")
    return targets_output_list  # Directly return list of target dicts
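Because this backup copy's detect_hardhat_with_sahi returns either an error dict or a list of "nohelmet" targets, a caller has to branch on the result type. A small sketch of that handling, assuming the function is in scope (the file name contains a space, so it is not directly importable as a module) and that an OpenCV image named frame is already loaded; both names are illustrative.

result = detect_hardhat_with_sahi(frame)
if isinstance(result, dict) and "error" in result:
    # The model was unavailable; surface the message to the API layer.
    print("detection unavailable:", result["error"])
else:
    # Already filtered to TARGET_CLASS_FOR_REPORTING ("nohelmet").
    print(f"{len(result)} 'nohelmet' targets")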
utils/hardhat_detector.py (Normal file, 199 lines added)
@@ -0,0 +1,199 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import os
import torch
from typing import Set, List  # Import Set and List for type hints

# Use the root logger configured in main.py
logger = logging.getLogger(__name__)

# --- 1. Model and class configuration (adjust this section to match your actual model) ---

# Class-name list for the new "state detection" model. The order must match the training order.
CLASS_NAMES_PPE_STATE_MODEL = ["wearingall", "noequipment", "nohelmet", "novest"]

# Violation type strings the backend expects to receive
BACKEND_VIOLATION_CODE_NO_HELMET = "nohelmet"
BACKEND_VIOLATION_CODE_NO_VEST = "novest"

# Weights file path for the new model
MODEL_WEIGHTS_PATH_HELMET = "models/ppe_state_model/best.pt"


# --- 2. Keep and adapt the existing call interface ---

# SAHI slicing parameters
SAHI_PARAMS = {"slice_height": 1440, "slice_width": 1440, "overlap_height_ratio": 0.3, "overlap_width_ratio": 0.3}

# Global model instance
detection_model_helmet = None  # Variable name kept as 'detection_model_helmet' for compatibility with the old version
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def initialize_helmet_model():
    """
    Initialization function that follows the calling convention expected by main.py.
    It loads the new PPE state model.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        if not os.path.exists(MODEL_WEIGHTS_PATH_HELMET):
            error_msg = f"Model weights file not found at: {MODEL_WEIGHTS_PATH_HELMET}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        logger.info(f"Loading hardhat state detection model from {MODEL_WEIGHTS_PATH_HELMET}...")
        try:
            detection_model_helmet = AutoDetectionModel.from_pretrained(
                model_type="yolov8",
                model_path=MODEL_WEIGHTS_PATH_HELMET,
                confidence_threshold=0.83,
                device=DEVICE,
            )
            # Optional verification step
            loaded_class_names = list(detection_model_helmet.model.names.values())
            if sorted(loaded_class_names) != sorted(CLASS_NAMES_PPE_STATE_MODEL):
                logger.warning(f"Model class names mismatch! Code expects {CLASS_NAMES_PPE_STATE_MODEL} "
                               f"but model has {loaded_class_names}. Please check consistency.")

            logger.info("hardhat state detection model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading hardhat state model: {e}", exc_info=True)
            raise RuntimeError(f"Could not load hardhat state model: {e}")
    return detection_model_helmet


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list) -> List:
    """
    Helper that runs SAHI inference and returns a list of all valid predictions.
    """
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.5
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {e}", exc_info=True)
        return []

    object_predictions = sahi_result.object_prediction_list if sahi_result else []

    valid_predictions = []
    for pred in object_predictions:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < len(model_class_names):
            valid_predictions.append(pred)
        else:
            logger.warning(f"SAHI Raw: Detected class ID {cat_id} is out of range. Ignoring prediction.")
    return valid_predictions


def map_model_output_to_backend_violations(class_name: str) -> list[str]:
    """
    Translates a state class name emitted by the model into the list of violation codes the backend needs.
    """
    if class_name == "nohelmet":
        return [BACKEND_VIOLATION_CODE_NO_HELMET]
    elif class_name == "novest":
        return [BACKEND_VIOLATION_CODE_NO_VEST]
    elif class_name == "noequipment":
        return [BACKEND_VIOLATION_CODE_NO_HELMET, BACKEND_VIOLATION_CODE_NO_VEST]
    return []


def detect_hardhat_with_sahi(image_cv: np.ndarray, requested_violations: Set[str], extract: bool = False):
    """
    The main detection function now takes a `requested_violations` set used to filter the final output.
    """
    global detection_model_helmet
    if detection_model_helmet is None:
        logger.warning("hardhat state model was not loaded. Initializing on first call.")
        try:
            initialize_helmet_model()
        except Exception as e:
            logger.error(f"Failed to initialize model on demand: {e}")
            raise RuntimeError(f"hardhat model could not be initialized: {e}")

    logger.info(f"Executing PPE detection, requested violations: {requested_violations}...")

    img_height, img_width, _ = image_cv.shape

    raw_sahi_predictions = _run_sahi_and_process_results(
        image_cv, detection_model_helmet, SAHI_PARAMS, CLASS_NAMES_PPE_STATE_MODEL
    )
    if not raw_sahi_predictions:
        logger.info("No raw objects detected by SAHI.")
        return []

    logger.info(f"SAHI raw detection count: {len(raw_sahi_predictions)}. Applying post-processing filters...")

    # --- Post-processing filter logic ---
    filtered_predictions = []
    MIN_PERSON_HEIGHT_PX = 40
    MAX_PERSON_WIDTH_RATIO = 0.21
    MAX_PERSON_HEIGHT_RATIO = 0.33
    MIN_ASPECT_RATIO = 1.2
    MAX_ASPECT_RATIO = 5.0
    FEET_POSITION_THRESHOLD_RATIO = 0.05

    for pred in raw_sahi_predictions:
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())
        score = pred.score.value
        bbox_width = maxx - minx
        bbox_height = maxy - miny
        if bbox_height < MIN_PERSON_HEIGHT_PX: continue
        relative_width = bbox_width / img_width
        relative_height = bbox_height / img_height
        if relative_width > MAX_PERSON_WIDTH_RATIO or relative_height > MAX_PERSON_HEIGHT_RATIO: continue
        if bbox_width == 0: continue
        aspect_ratio = bbox_height / bbox_width
        if not (MIN_ASPECT_RATIO <= aspect_ratio <= MAX_ASPECT_RATIO): continue
        if maxy < (img_height * FEET_POSITION_THRESHOLD_RATIO): continue
        bbox_area = bbox_width * bbox_height
        image_area = img_width * img_height
        if bbox_area > (image_area * 0.1) and score < 0.5: continue
        filtered_predictions.append(pred)

    logger.info(f"After filtering, {len(filtered_predictions)} valid person states remain. Mapping to violations...")

    # --- Mapping and final filtering logic ---
    targets_output_list = []
    for pred in filtered_predictions:
        class_id = int(pred.category.id)
        class_name = CLASS_NAMES_PPE_STATE_MODEL[class_id]
        score = round(float(pred.score.value), 4)
        minx, miny, maxx, maxy = map(int, pred.bbox.to_xyxy())

        logger.info(f"Processing prediction -> Model says: '{class_name}', Score: {score}, BBox: [{minx},{miny},{maxx},{maxy}]")

        all_possible_violations = map_model_output_to_backend_violations(class_name)

        final_violations_to_report = [
            v for v in all_possible_violations if v in requested_violations
        ]

        if not final_violations_to_report:
            continue

        bbox_width = maxx - minx
        bbox_height = maxy - miny

        for code in final_violations_to_report:
            logger.info(f"  └─ Violation Found & Requested: Type: '{code}', BBox: [{minx},{miny},{maxx},{maxy}]")
            targets_output_list.append({
                "type": code,
                "size": [int(bbox_width), int(bbox_height)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": score
            })

    logger.info(f"Prepared {len(targets_output_list)} filtered violation targets for response.")
    return targets_output_list
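The mapping helper above fans a single state class out into zero, one, or two backend violation codes, which detect_hardhat_with_sahi then intersects with the caller's requested_violations set. A small sketch of that behavior; the requested set below is illustrative only.

from utils.hardhat_detector import map_model_output_to_backend_violations

print(map_model_output_to_backend_violations("noequipment"))  # ['nohelmet', 'novest']
print(map_model_output_to_backend_violations("nohelmet"))     # ['nohelmet']
print(map_model_output_to_backend_violations("wearingall"))   # [] -> nothing to report

# Intersection with the caller's request, as done inside detect_hardhat_with_sahi:
requested = {"novest"}  # hypothetical request from the backend
reported = [v for v in map_model_output_to_backend_violations("noequipment") if v in requested]
print(reported)  # ['novest']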
utils/pho_detector.py (Normal file, 163 lines added)
@@ -0,0 +1,163 @@
# utils/pho_detector.py

import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
from typing import List, Dict, Any

# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Module-Specific Configuration ---

# All classes this module can detect.
ALL_SUPPORTED_CLASSES = ['pho', 'shelves', 'pile', 'hole']

# Configuration for the first model (detects 'pho' and 'shelves').
MODEL_1_CONFIG = {
    "name": "pho_model",
    # IMPORTANT: Update this path to your actual model location.
    "weights_path": "models/photovoltaic_model/pho_model/best.pt",
    # Class order must match the model's training IDs (0: 'pho', 1: 'shelves').
    "class_names": ['pho', 'shelves'],
    "confidence_threshold": 0.7
}

# Configuration for the second model (detects 'pile' and 'hole').
MODEL_2_CONFIG = {
    "name": "pile_model",
    # IMPORTANT: Update this path to your actual model location.
    "weights_path": "models/photovoltaic_model/pile_model/best.pt",
    # Class order must match the model's training IDs (0: 'pile', 1: 'hole').
    "class_names": ['pile', 'hole'],
    "confidence_threshold": 0.7
}

# Unified SAHI slicing parameters for all tasks in this module.
SAHI_PARAMS = {
    "slice_height": 800,
    "slice_width": 800,
    "overlap_height_ratio": 0.2,
    "overlap_width_ratio": 0.2
}

# Global dictionary to cache loaded model instances.
detection_models: Dict[str, AutoDetectionModel] = {}


# --- Model Management Functions ---

def initialize_model(model_config: Dict[str, Any]) -> AutoDetectionModel:
    """Initializes a single model based on its configuration if not already loaded."""
    global detection_models
    model_name = model_config["name"]

    if model_name not in detection_models:
        logger.info(f"Loading detection model '{model_name}' from {model_config['weights_path']}...")
        try:
            model = AutoDetectionModel.from_pretrained(
                model_type="ultralytics",
                model_path=model_config["weights_path"],
                confidence_threshold=model_config["confidence_threshold"],
            )
            detection_models[model_name] = model
            logger.info(f"Model '{model_name}' loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not load model '{model_name}': {e}")

    return detection_models[model_name]


def initialize_all_pho_models():
    """Initializes all models defined in this module. Called once at application startup."""
    logger.info("Pre-initializing all photovoltaic-related models...")
    initialize_model(MODEL_1_CONFIG)
    initialize_model(MODEL_2_CONFIG)
    logger.info("All photovoltaic models have been initialized.")


# --- Inference and Core Logic ---

def _run_inference_for_model(image_cv: np.ndarray, model_config: Dict[str, Any]) -> List[Dict]:
    """Internal function to run SAHI prediction for a single specified model."""
    try:
        model = initialize_model(model_config)
    except RuntimeError as e:
        logger.error(f"Cannot run inference; model '{model_config['name']}' failed to initialize: {e}")
        return []

    logger.info(f"Running SAHI prediction with model: {model_config['name']}...")
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            **SAHI_PARAMS,
            verbose=0
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed for model '{model_config['name']}': {e}", exc_info=True)
        return []

    targets_output_list = []
    # Use the specific class list for this model to map IDs to names.
    model_class_names = model_config["class_names"]

    for pred in sahi_result.object_prediction_list:
        cat_id = int(pred.category.id)

        # Core Logic: Map the model's local ID to its correct global class name.
        if 0 <= cat_id < len(model_class_names):
            class_name = model_class_names[cat_id]
            minx, miny, maxx, maxy = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy

            targets_output_list.append({
                "type": class_name,
                "size": [int(maxx - minx), int(maxy - miny)],
                "leftTopPoint": [int(minx), int(miny)],
                "score": round(float(pred.score.value), 4)
            })
        else:
            logger.warning(f"Model '{model_config['name']}' detected an out-of-range class ID: {cat_id}. Ignoring.")

    logger.info(f"Model '{model_config['name']}' found {len(targets_output_list)} raw objects.")
    return targets_output_list


def detect_objects_by_class(image_cv: np.ndarray, target_classes: List[str]) -> List[Dict]:
    """
    Main public function. Detects objects by intelligently dispatching to the correct model(s).

    Args:
        image_cv: The input OpenCV image.
        target_classes: A list of class names to detect (e.g., ['pho', 'hole']).

    Returns:
        A list of detection dictionaries for the successfully found targets.
    """
    final_results = []

    # Determine which targets belong to Model 1.
    model_1_targets = [cls for cls in target_classes if cls in MODEL_1_CONFIG["class_names"]]
    # Determine which targets belong to Model 2.
    model_2_targets = [cls for cls in target_classes if cls in MODEL_2_CONFIG["class_names"]]

    # Run Model 1 only if it's needed.
    if model_1_targets:
        logger.info(f"Model 1 ('{MODEL_1_CONFIG['name']}') triggered for targets: {model_1_targets}")
        model_1_results = _run_inference_for_model(image_cv, MODEL_1_CONFIG)
        # Filter the raw results to only include what the user asked for.
        final_results.extend([res for res in model_1_results if res["type"] in model_1_targets])

    # Run Model 2 only if it's needed.
    if model_2_targets:
        logger.info(f"Model 2 ('{MODEL_2_CONFIG['name']}') triggered for targets: {model_2_targets}")
        model_2_results = _run_inference_for_model(image_cv, MODEL_2_CONFIG)
        # Filter the raw results to only include what the user asked for.
        final_results.extend([res for res in model_2_results if res["type"] in model_2_targets])

    logger.info(f"Photovoltaic detection complete. Returning {len(final_results)} filtered targets.")
    return final_results


# --- Export for main.py ---
# This list is imported by main.py to know which tasks to delegate to this module.
PHO_DETECTOR_SUPPORTED_CLASSES = ALL_SUPPORTED_CLASSES
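A minimal sketch of the class-based dispatch above, assuming an OpenCV image named image_cv is already loaded (the variable name is illustrative); asking only for 'pho' and 'hole' makes the module run both models but return only those two classes.

from utils.pho_detector import initialize_all_pho_models, detect_objects_by_class

initialize_all_pho_models()  # warm both models at application startup
targets = detect_objects_by_class(image_cv, ["pho", "hole"])
# 'pho' triggers MODEL_1_CONFIG and 'hole' triggers MODEL_2_CONFIG; 'shelves' and 'pile'
# detections from the same runs are filtered out before the list is returned.
print(len(targets))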
utils/smoking_detector.py (Normal file, 258 lines added)
@@ -0,0 +1,258 @@
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
import torch  # Used for device detection and YOLO model loading
from ultralytics import YOLO  # Used for stage-two ROI detection

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)  # Added so that logger output is visible

# --- Updated configuration ---
CLASS_NAMES_BEHAVIOR = ["smoke", "face"]
CIGARETTE_CLASS_ID = 0
FACE_CLASS_ID = 1

MODEL_WEIGHTS_PATH_BEHAVIOR = "models/smoking_model/best.pt"

# SAHI slicing parameters
SAHI_PARAMS_FACE_LARGE = {"slice_height": 1280, "slice_width": 1280, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_FACE_SMALL = {"slice_height": 640, "slice_width": 640, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}

# YOLO confidence threshold for cigarette detection inside the ROI
CIGARETTE_IN_ROI_CONF_THRESHOLD = 0.6
FACE_ROI_PADDING = 30

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global model instances
detection_model_sahi_behavior = None
yolo_model_direct_behavior = None


def initialize_smoking_model():
    global detection_model_sahi_behavior, yolo_model_direct_behavior
    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.info(f"Loading behavior detection model (face & cigarette) from {MODEL_WEIGHTS_PATH_BEHAVIOR}...")
        try:
            # 1. SAHI-wrapped model, used for global sliced detection
            detection_model_sahi_behavior = AutoDetectionModel.from_pretrained(
                model_type="ultralytics",
                model_path=MODEL_WEIGHTS_PATH_BEHAVIOR,
                confidence_threshold=0.7,
                device=DEVICE,
            )
            # 2. Model loaded directly with Ultralytics, used for fast ROI detection
            yolo_model_direct_behavior = YOLO(MODEL_WEIGHTS_PATH_BEHAVIOR)
            yolo_model_direct_behavior.to(DEVICE)

            if hasattr(yolo_model_direct_behavior, 'names') and yolo_model_direct_behavior.names:
                model_names = list(yolo_model_direct_behavior.names.values())
                if model_names != CLASS_NAMES_BEHAVIOR and sorted(model_names) != sorted(CLASS_NAMES_BEHAVIOR):
                    logger.warning(f"Mismatch in class names! Expected from code: {CLASS_NAMES_BEHAVIOR}, "
                                   f"Got from model: {model_names}. Please verify consistency.")
                else:
                    logger.info(f"Model class names confirmed: {model_names}")
            else:
                logger.warning("Could not retrieve class names from loaded yolo_model_direct_behavior for verification.")

            logger.info("Behavior detection models (SAHI and Direct YOLO) loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading behavior model: {e}", exc_info=True)
            detection_model_sahi_behavior = None
            yolo_model_direct_behavior = None
            raise RuntimeError(f"Could not load behavior model: {e}")
    return detection_model_sahi_behavior, yolo_model_direct_behavior


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, target_class_id: int, target_class_name: str, model_actual_class_names: list):
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv, detection_model=model, slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"], overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"], verbose=0,
            postprocess_type="GREEDYNMM",
            postprocess_match_metric="IOS",
            postprocess_match_threshold=0.7,
            postprocess_class_agnostic=False
        )
    except Exception as e:
        logger.error(f"SAHI prediction failed for class '{target_class_name}' with params {sahi_params}: {e}", exc_info=True)
        return []

    valid_predictions = []
    if hasattr(sahi_result, 'object_prediction_list'):
        for pred in sahi_result.object_prediction_list:
            cat_id_from_sahi = int(pred.category.id)

            if not (0 <= cat_id_from_sahi < len(model_actual_class_names)):
                logger.warning(f"SAHI Raw: Detected class ID {cat_id_from_sahi} for '{target_class_name}' "
                               f"is out of range for actual model class names ({len(model_actual_class_names)} classes). Ignoring pred.")
                continue

            if cat_id_from_sahi == target_class_id:
                valid_predictions.append(pred)

    logger.info(f"SAHI found {len(valid_predictions)} instances of '{target_class_name}' with params {sahi_params}.")
    return valid_predictions


def detect_smoking_behavior(image_cv: np.ndarray, extract: bool = False):
    global detection_model_sahi_behavior, yolo_model_direct_behavior

    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.warning("Behavior model (SAHI or Direct YOLO) was not loaded. Attempting to initialize now.")
        try:
            initialize_smoking_model()
        except RuntimeError:
            logger.error("Behavior model could not be initialized for detection.")
            return {"error": "Behavior model is not available."}

    if detection_model_sahi_behavior is None or yolo_model_direct_behavior is None:
        logger.error("Behavior model components are still None after initialization attempt.")
        return {"error": "Behavior model failed to initialize."}

    model_actual_class_names_dict = yolo_model_direct_behavior.names
    if not model_actual_class_names_dict:
        logger.error("Cannot determine actual class names from the loaded YOLO model.")
        return {"error": "Model class names missing."}

    max_id = max(model_actual_class_names_dict.keys())
    model_actual_class_names_list = [""] * (max_id + 1)
    for id_val, name_val in model_actual_class_names_dict.items():
        model_actual_class_names_list[id_val] = name_val

    logger.info("Executing two-stage smoking behavior detection...")

    # --- Stage 1: detect faces with SAHI ---
    detected_faces_sahi = []
    logger.info(f"Attempting FACE detection with SAHI (Large Slices {SAHI_PARAMS_FACE_LARGE})...")
    face_preds_large = _run_sahi_and_process_results(
        image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_LARGE, FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], model_actual_class_names_list
    )
    if face_preds_large:
        detected_faces_sahi = face_preds_large
    else:
        logger.info(f"No SAHI face detections (large slices). Retrying with small slices {SAHI_PARAMS_FACE_SMALL}...")
        face_preds_small = _run_sahi_and_process_results(
            image_cv, detection_model_sahi_behavior, SAHI_PARAMS_FACE_SMALL, FACE_CLASS_ID, CLASS_NAMES_BEHAVIOR[FACE_CLASS_ID], model_actual_class_names_list
        )
        if face_preds_small:
            detected_faces_sahi = face_preds_small
        else:
            logger.info("No SAHI face detections found even with small slices.")

    # <<<--- Logic enhancement 1: strictly qualify the detected faces --- START --->>>
    if detected_faces_sahi:
        original_face_count = len(detected_faces_sahi)
        qualified_faces = []
        for face_pred in detected_faces_sahi:
            # 1. Confidence check (this threshold is tunable)
            if face_pred.score.value < 0.65:
                continue

            # 2. Size check (these pixel values are tunable)
            x1, y1, x2, y2 = face_pred.bbox.to_xyxy()
            width, height = x2 - x1, y2 - y1
            if width < 32 or height < 32:
                continue

            # 3. Aspect-ratio check (this range is tunable)
            aspect_ratio = width / height if height > 0 else 0
            if not (0.6 < aspect_ratio < 1.6):
                continue

            qualified_faces.append(face_pred)

        logger.info(f"Face Qualification: Initial candidates: {original_face_count}, Qualified after filtering: {len(qualified_faces)}")
        detected_faces_sahi = qualified_faces  # Replace the original list with the qualified one
    # <<<--- Logic enhancement 1 --- END --->>>

    targets_output_list = []
    if not detected_faces_sahi:
        logger.info("No QUALIFIED faces found in Stage 1. No smoking behavior to report.")
        return targets_output_list

    logger.info(f"Found {len(detected_faces_sahi)} qualified faces in Stage 1. Proceeding to Stage 2...")

    # --- Stage 2: use direct YOLO to detect cigarettes inside each face ROI ---
    for face_pred_sahi in detected_faces_sahi:
        x1_face, y1_face, x2_face, y2_face = map(int, face_pred_sahi.bbox.to_xyxy())

        # Define and crop the ROI
        roi_x1 = max(0, x1_face - FACE_ROI_PADDING)
        roi_y1 = max(0, y1_face - FACE_ROI_PADDING)
        roi_x2 = min(image_cv.shape[1], x2_face + FACE_ROI_PADDING)
        roi_y2 = min(image_cv.shape[0], y2_face + FACE_ROI_PADDING)
        face_roi_crop = image_cv[roi_y1:roi_y2, roi_x1:roi_x2]

        if face_roi_crop.size == 0:
            logger.debug(f"Face ROI at [{x1_face},{y1_face},{x2_face},{y2_face}] resulted in empty crop. Skipping.")
            continue

        # <<<--- Logic enhancement 2: cross-validate inside the ROI --- START --->>>
        roi_results = yolo_model_direct_behavior.predict(source=face_roi_crop, verbose=False, device=DEVICE, conf=CIGARETTE_IN_ROI_CONF_THRESHOLD)

        is_face_confirmed_in_roi = False
        found_cigarette_in_roi = False
        cigarette_details_for_this_face = []

        if roi_results and len(roi_results) > 0:
            result_for_roi = roi_results[0]
            for box in result_for_roi.boxes:
                class_id = int(box.cls)
                if class_id == FACE_CLASS_ID:
                    is_face_confirmed_in_roi = True
                elif class_id == CIGARETTE_CLASS_ID:
                    xc1_roi, yc1_roi, xc2_roi, yc2_roi = map(int, box.xyxy[0].tolist())

                    cig_minx_global = xc1_roi + roi_x1
                    cig_miny_global = yc1_roi + roi_y1
                    cig_maxx_global = xc2_roi + roi_x1
                    cig_maxy_global = yc2_roi + roi_y1

                    cig_width_global = cig_maxx_global - cig_minx_global
                    cig_height_global = cig_maxy_global - cig_miny_global

                    # <<<--- Logic enhancement 3: strict size and aspect-ratio checks for the cigarette --- START --->>>
                    long_side = max(cig_width_global, cig_height_global)
                    short_side = min(cig_width_global, cig_height_global)

                    # 1. Absolute size check (per the stated requirement)
                    if long_side > 100 or short_side > 40:
                        logger.info(f"    REJECTED Cigarette: Absolute size out of bounds. long={long_side}, short={short_side}")
                        continue

                    # 2. Aspect-ratio check (targets false positives such as utility poles)
                    if short_side > 0:
                        aspect_ratio = long_side / short_side
                        # Tunable: a real cigarette is much longer than it is wide; this range rejects oddly proportioned objects.
                        if not (2.5 < aspect_ratio < 25.0):
                            logger.info(f"    REJECTED Cigarette: Aspect ratio out of bounds. ratio={aspect_ratio:.2f}")
                            continue
                    elif long_side > 0:  # Handle the thin-line case where short_side is 0 but long_side is not
                        logger.info("    REJECTED Cigarette: BBox is a zero-width/height line.")
                        continue
                    else:  # Both width and height are 0; skip
                        continue
                    # <<<--- Logic enhancement 3 --- END --->>>

                    # Only cigarettes that pass every check are treated as valid
                    found_cigarette_in_roi = True
                    cigarette_details_for_this_face.append({
                        "type": CLASS_NAMES_BEHAVIOR[CIGARETTE_CLASS_ID],
                        "size": [int(cig_width_global), int(cig_height_global)],
                        "leftTopPoint": [int(cig_minx_global), int(cig_miny_global)],
                        "score": round(float(box.conf), 4)
                    })

        # Final verdict: both a face and a (qualified) cigarette must be confirmed inside the ROI
        if is_face_confirmed_in_roi and found_cigarette_in_roi:
            logger.info(f"  SUCCESS: Cross-validation passed. Face confirmed and qualified cigarette found in ROI of face at [{x1_face},{y1_face},{x2_face},{y2_face}].")
            targets_output_list.extend(cigarette_details_for_this_face)
        else:
            logger.info(f"  INFO: Cross-validation failed for face ROI. Face confirmed in ROI: {is_face_confirmed_in_roi}, Cigarette found in ROI: {found_cigarette_in_roi}. Ignoring results for this face.")
        # <<<--- Logic enhancement 2 --- END --->>>

    logger.info(f"Processed all faces. Final targets_output_list contains {len(targets_output_list)} cigarette detections after all logical checks.")
    return targets_output_list
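A minimal end-to-end sketch of the two-stage pipeline above (SAHI face detection, then a per-ROI cigarette check with the direct YOLO model), assuming the weights exist at MODEL_WEIGHTS_PATH_BEHAVIOR and a local test image named "worker.jpg"; the image path is illustrative only.

import cv2
from utils.smoking_detector import initialize_smoking_model, detect_smoking_behavior

initialize_smoking_model()           # loads both the SAHI-wrapped and the direct YOLO model
image = cv2.imread("worker.jpg")     # assumed test image path
result = detect_smoking_behavior(image)
if isinstance(result, dict):         # {"error": ...} when the models are unavailable
    print(result["error"])
else:
    for cig in result:               # each entry is a qualified cigarette detection
        print(cig["type"], cig["leftTopPoint"], cig["size"], cig["score"])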