Files
video/ocr/model_violation_detector.py
2025-09-03 16:22:21 +08:00

136 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import yaml
from pathlib import Path
from .ocr_violation_detector import OCRViolationDetector
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
from .face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
    """
    Multi-model violation detector.

    Runs the OCR forbidden-word detector, the face recognizer and the YOLO
    object detector one after another on a single frame, and returns as
    soon as any of them reports a violation.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 yolo_model_path: str,
                 known_faces_dir: str,
                 ocr_confidence_threshold: float = 0.5):
        """
        Build the three underlying detectors.

        Args:
            forbidden_words_path: Forbidden-words source, forwarded to
                ``OCRViolationDetector``.
            ocr_config_path: OCR configuration path, forwarded to
                ``OCRViolationDetector``.
            yolo_model_path: Model path, forwarded to the YOLO
                ``ViolationDetector``.
            known_faces_dir: Directory forwarded to ``FaceRecognizer``.
            ocr_confidence_threshold: Confidence threshold forwarded to
                ``OCRViolationDetector``.
        """
        # OCR forbidden-word detector
        self.ocr_detector = OCRViolationDetector(
            forbidden_words_path=forbidden_words_path,
            ocr_config_path=ocr_config_path,
            ocr_confidence_threshold=ocr_confidence_threshold
        )
        # Face recognizer
        self.face_recognizer = FaceRecognizer(
            known_faces_dir=known_faces_dir
        )
        # YOLO object detector
        self.yolo_detector = YoloViolationDetector(
            model_path=yolo_model_path
        )
        print("多模型违规检测器初始化完成")

    @staticmethod
    def _class_name(names, class_id):
        """Map a (possibly float) class id to its label in ``names``; fall back to the id as text."""
        try:
            idx = int(class_id)
        except (TypeError, ValueError):
            return str(class_id)
        try:
            # ``names`` may be a dict {id: label} or a list indexed by id.
            return str(names[idx])
        except (KeyError, IndexError, TypeError):
            return str(idx)

    def detect_violations(self, frame):
        """
        Run OCR -> face recognition -> YOLO on ``frame``; stop at the first violation.

        A detector that raises is reported via ``print`` and skipped, so one
        failing model does not prevent the remaining models from running.

        Args:
            frame: Image passed unchanged to every detector (presumably a
                cv2/numpy image — TODO confirm against the detectors).

        Returns:
            tuple: ``(has_violation, source, details)``.
                - ``source`` is ``"ocr"``, ``"face"``, ``"yolo"`` or ``None``.
                - ``details`` is a dict describing the hit, or ``None`` when
                  nothing was found.
                - For YOLO hits, ``details["classes"]`` is the model's full
                  class-name table (kept for backward compatibility), and
                  ``details["class_names"]`` lists the detected label per box.
        """
        # 1. OCR forbidden-word check
        try:
            ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
        except Exception as e:
            print(f"错误: OCR检测出错: {str(e)}")
        else:
            if ocr_has_violation:
                details = {
                    "words": ocr_words,
                    "confidences": ocr_confs
                }
                print(f"警告: OCR检测到违禁内容: {details}")
                return (True, "ocr", details)

        # 2. Face recognition check
        try:
            face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
        except Exception as e:
            print(f"错误: 人脸识别出错: {str(e)}")
        else:
            if face_has_violation:
                details = {
                    "name": face_name,
                    "similarity": face_similarity
                }
                print(f"警告: 人脸识别到违规人员: {details}")
                return (True, "face", details)

        # 3. YOLO object detection
        try:
            yolo_results = self.yolo_detector.detect(frame)
            # A missing result or missing boxes means "nothing detected",
            # not an error worth logging.
            boxes = getattr(yolo_results, "boxes", None)
            if boxes is not None and len(boxes) > 0:
                names = yolo_results.names
                class_ids = boxes.cls.tolist()
                details = {
                    # Full class-name table of the model (NOT only the
                    # detected classes); kept for backward compatibility.
                    "classes": names,
                    "boxes": boxes.xyxy.tolist(),
                    "confidences": boxes.conf.tolist(),
                    "class_ids": class_ids,
                    # Label of each detected box, aligned with "boxes".
                    "class_names": [self._class_name(names, c) for c in class_ids],
                }
                print(f"警告: YOLO检测到违规目标: {details}")
                return (True, "yolo", details)
        except Exception as e:
            print(f"错误: YOLO检测出错: {str(e)}")

        # No detector reported a violation
        return (False, None, None)
def load_config(config_path: str) -> dict:
    """
    Load a YAML configuration file and return its top-level contents.

    The file is parsed with ``yaml.safe_load`` (no arbitrary Python object
    construction). An empty or comment-only file yields ``{}`` instead of
    ``None``, so callers can index or ``.get()`` the result safely. A
    non-mapping top-level document (e.g. a list) is returned as-is.

    Args:
        config_path: Path to the YAML file, read as UTF-8.

    Returns:
        dict: The parsed configuration (``{}`` for an empty document).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        yaml.YAMLError: If the file is not valid YAML.
        Exception: Any other read error, re-raised after being printed.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"错误: 配置文件未找到: {config_path}")
        raise
    except yaml.YAMLError as e:
        print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
        raise
    except Exception as e:
        print(f"错误: 加载配置文件出错: {str(e)}")
        raise
    # safe_load returns None for an empty (or comment-only) document.
    return {} if data is None else data
# 使用示例
# if __name__ == "__main__":
# # 加载配置文件
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
#
# # 初始化多模型检测器
# detector = MultiModelViolationDetector(
# forbidden_words_path=config["forbidden_words_path"],
# ocr_config_path=config["ocr_config_path"],
# yolo_model_path=config["yolo_model_path"],
# known_faces_dir=config["known_faces_dir"],
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
# )
#
# # 读取测试图像(可替换为视频帧读取逻辑)
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
# if test_image_path:
# frame = cv2.imread(test_image_path)
#
# if frame is not None:
# has_violation, violation_type, details = detector.detect_violations(frame)
# if has_violation:
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# else:
# print("未检测到任何违规内容")
# else:
# print(f"无法读取测试图像: {test_image_path}")
# else:
# print("配置文件中未指定测试图像路径")