2025-09-03 16:22:21 +08:00
|
|
|
|
import os
|
2025-09-03 14:38:42 +08:00
|
|
|
|
|
2025-09-03 16:22:21 +08:00
|
|
|
|
import cv2
|
|
|
|
|
import yaml
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from .ocr_violation_detector import OCRViolationDetector
|
|
|
|
|
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
|
|
|
|
from .face_recognizer import FaceRecognizer
|
2025-09-03 14:38:42 +08:00
|
|
|
|
|
|
|
|
|
class MultiModelViolationDetector:
|
|
|
|
|
"""
|
2025-09-03 16:22:21 +08:00
|
|
|
|
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果
|
2025-09-03 14:38:42 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
forbidden_words_path: str,
|
2025-09-03 16:22:21 +08:00
|
|
|
|
ocr_config_path: str,
|
2025-09-03 14:38:42 +08:00
|
|
|
|
yolo_model_path: str,
|
|
|
|
|
known_faces_dir: str,
|
|
|
|
|
ocr_confidence_threshold: float = 0.5):
|
|
|
|
|
"""
|
|
|
|
|
初始化所有检测模型
|
|
|
|
|
"""
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# 初始化OCR检测器
|
2025-09-03 14:38:42 +08:00
|
|
|
|
self.ocr_detector = OCRViolationDetector(
|
|
|
|
|
forbidden_words_path=forbidden_words_path,
|
2025-09-03 16:22:21 +08:00
|
|
|
|
ocr_config_path=ocr_config_path,
|
2025-09-03 14:38:42 +08:00
|
|
|
|
ocr_confidence_threshold=ocr_confidence_threshold
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 初始化人脸识别器
|
|
|
|
|
self.face_recognizer = FaceRecognizer(
|
|
|
|
|
known_faces_dir=known_faces_dir
|
|
|
|
|
)
|
|
|
|
|
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# 初始化YOLO检测器
|
2025-09-03 14:38:42 +08:00
|
|
|
|
self.yolo_detector = YoloViolationDetector(
|
|
|
|
|
model_path=yolo_model_path
|
|
|
|
|
)
|
|
|
|
|
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print("多模型违规检测器初始化完成")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
|
|
|
|
|
def detect_violations(self, frame):
|
|
|
|
|
"""
|
|
|
|
|
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
|
|
|
|
"""
|
|
|
|
|
# 1. 首先进行OCR违禁词检测
|
|
|
|
|
try:
|
|
|
|
|
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
|
|
|
|
|
if ocr_has_violation:
|
|
|
|
|
details = {
|
|
|
|
|
"words": ocr_words,
|
|
|
|
|
"confidences": ocr_confs
|
|
|
|
|
}
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print(f"警告: OCR检测到违禁内容: {details}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
return (True, "ocr", details)
|
|
|
|
|
except Exception as e:
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print(f"错误: OCR检测出错: {str(e)}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
|
|
|
|
|
# 2. 接着进行人脸识别检测
|
|
|
|
|
try:
|
|
|
|
|
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
|
|
|
|
|
if face_has_violation:
|
|
|
|
|
details = {
|
|
|
|
|
"name": face_name,
|
|
|
|
|
"similarity": face_similarity
|
|
|
|
|
}
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print(f"警告: 人脸识别到违规人员: {details}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
return (True, "face", details)
|
|
|
|
|
except Exception as e:
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print(f"错误: 人脸识别出错: {str(e)}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# 3. 最后进行YOLO目标检测
|
2025-09-03 14:38:42 +08:00
|
|
|
|
try:
|
|
|
|
|
yolo_results = self.yolo_detector.detect(frame)
|
|
|
|
|
if len(yolo_results.boxes) > 0:
|
|
|
|
|
details = {
|
|
|
|
|
"classes": yolo_results.names,
|
2025-09-03 16:22:21 +08:00
|
|
|
|
"boxes": yolo_results.boxes.xyxy.tolist(),
|
|
|
|
|
"confidences": yolo_results.boxes.conf.tolist(),
|
|
|
|
|
"class_ids": yolo_results.boxes.cls.tolist()
|
2025-09-03 14:38:42 +08:00
|
|
|
|
}
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print(f"警告: YOLO检测到违规目标: {details}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
return (True, "yolo", details)
|
|
|
|
|
except Exception as e:
|
2025-09-03 16:22:21 +08:00
|
|
|
|
print(f"错误: YOLO检测出错: {str(e)}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
|
|
|
|
|
# 所有检测均未发现违规
|
|
|
|
|
return (False, None, None)
|
|
|
|
|
|
|
|
|
|
|
2025-09-03 16:22:21 +08:00
|
|
|
|
def load_config(config_path: str) -> dict:
|
|
|
|
|
"""加载YAML配置文件"""
|
|
|
|
|
try:
|
|
|
|
|
with open(config_path, 'r', encoding='utf-8') as f:
|
|
|
|
|
return yaml.safe_load(f)
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
print(f"错误: 配置文件未找到: {config_path}")
|
|
|
|
|
raise
|
|
|
|
|
except yaml.YAMLError as e:
|
|
|
|
|
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
|
|
|
|
|
raise
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"错误: 加载配置文件出错: {str(e)}")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 使用示例
|
2025-09-03 14:38:42 +08:00
|
|
|
|
# if __name__ == "__main__":
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# # 加载配置文件
|
|
|
|
|
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
|
2025-09-03 14:38:42 +08:00
|
|
|
|
#
|
|
|
|
|
# # 初始化多模型检测器
|
|
|
|
|
# detector = MultiModelViolationDetector(
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# forbidden_words_path=config["forbidden_words_path"],
|
|
|
|
|
# ocr_config_path=config["ocr_config_path"],
|
|
|
|
|
# yolo_model_path=config["yolo_model_path"],
|
|
|
|
|
# known_faces_dir=config["known_faces_dir"],
|
|
|
|
|
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
|
2025-09-03 14:38:42 +08:00
|
|
|
|
# )
|
|
|
|
|
#
|
|
|
|
|
# # 读取测试图像(可替换为视频帧读取逻辑)
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
|
|
|
|
|
# if test_image_path:
|
|
|
|
|
# frame = cv2.imread(test_image_path)
|
2025-09-03 14:38:42 +08:00
|
|
|
|
#
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# if frame is not None:
|
|
|
|
|
# has_violation, violation_type, details = detector.detect_violations(frame)
|
|
|
|
|
# if has_violation:
|
|
|
|
|
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
|
|
|
|
# else:
|
|
|
|
|
# print("未检测到任何违规内容")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
# else:
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# print(f"无法读取测试图像: {test_image_path}")
|
2025-09-03 14:38:42 +08:00
|
|
|
|
# else:
|
2025-09-03 16:22:21 +08:00
|
|
|
|
# print("配置文件中未指定测试图像路径")
|