Files
video/ocr/model_violation_detector.py
2025-09-03 14:38:42 +08:00

133 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
from logger_config import logger
from ocr_violation_detector import OCRViolationDetector
from yolo_violation_detector import ViolationDetector as YoloViolationDetector
from face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
    """Run OCR, face-recognition and YOLO checks on a frame, in that order.

    The three detectors are called serially (YOLO last). The first detector
    that reports a violation decides the result; the remaining detectors are
    not run for that frame.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 yolo_model_path: str,
                 known_faces_dir: str,
                 ocr_confidence_threshold: float = 0.5):
        """Build the three underlying detectors.

        Args:
            forbidden_words_path: Path of the forbidden-words file.
            ocr_config_path: Path of the OCR configuration file (e.g. 1.yaml).
            yolo_model_path: Path of the YOLO model file.
            known_faces_dir: Directory holding the known faces.
            ocr_confidence_threshold: Confidence threshold forwarded to the
                OCR detector.
        """
        # OCR detector; the configuration file path is forwarded as-is.
        self.ocr_detector = OCRViolationDetector(
            forbidden_words_path=forbidden_words_path,
            ocr_config_path=ocr_config_path,
            ocr_confidence_threshold=ocr_confidence_threshold
        )
        # Face recognizer.
        self.face_recognizer = FaceRecognizer(
            known_faces_dir=known_faces_dir
        )
        # YOLO detector; built last, mirroring the order in which it is used.
        self.yolo_detector = YoloViolationDetector(
            model_path=yolo_model_path
        )
        logger.info("多模型违规检测器初始化完成")

    def detect_violations(self, frame):
        """Check one frame with OCR, then face recognition, then YOLO.

        Args:
            frame: Input video frame (documented by the original author as a
                BGR NumPy array). ``None`` is tolerated and treated as
                "nothing to check".

        Returns:
            tuple: ``(has_violation, violation_type, details)`` where
            ``violation_type`` is ``'ocr'``, ``'face'``, ``'yolo'`` or
            ``None``, and ``details`` is the dict produced by the detector
            that fired (``None`` when no violation was found).
        """
        if frame is None:
            # A failed capture/read yields no frame. Bail out once instead of
            # pushing None through every detector and logging a traceback
            # for each failure.
            logger.warning("输入帧为空，跳过违规检测")
            return (False, None, None)

        # Order matters: the first detector that reports a violation wins.
        checks = (
            ("ocr", self._check_ocr),
            ("face", self._check_face),
            ("yolo", self._check_yolo),
        )
        for violation_type, check in checks:
            details = check(frame)
            if details is not None:
                return (True, violation_type, details)

        # No detector reported a violation.
        return (False, None, None)

    def _check_ocr(self, frame):
        """Return OCR violation details for ``frame``, or None if none found."""
        try:
            has_violation, words, confidences = self.ocr_detector.detect(frame)
            if not has_violation:
                return None
            details = {
                "words": words,
                "confidences": confidences
            }
            logger.warning(f"OCR检测到违禁内容: {details}")
            return details
        except Exception as e:
            # Deliberate best effort: a failing detector must not prevent the
            # remaining detectors from running on this frame.
            logger.error(f"OCR检测出错: {str(e)}", exc_info=True)
            return None

    def _check_face(self, frame):
        """Return face violation details for ``frame``, or None if none found."""
        try:
            has_violation, name, similarity = self.face_recognizer.recognize(frame)
            if not has_violation:
                return None
            details = {
                "name": name,
                "similarity": similarity
            }
            logger.warning(f"人脸识别到违规人员: {details}")
            return details
        except Exception as e:
            # Deliberate best effort, see _check_ocr.
            logger.error(f"人脸识别出错: {str(e)}", exc_info=True)
            return None

    def _check_yolo(self, frame):
        """Return YOLO violation details for ``frame``, or None if none found.

        Any detected box counts as a violation; which classes are actually
        violations is left to the business rules of the model/caller.
        """
        try:
            yolo_results = self.yolo_detector.detect(frame)
            boxes = yolo_results.boxes
            # The result object may carry no boxes container at all; treat
            # that like an empty detection instead of failing on len(None).
            if boxes is None or len(boxes) == 0:
                return None
            details = {
                # NOTE(review): `names` is passed through unchanged; it looks
                # like the model's full class-id -> name mapping rather than
                # only the detected classes - confirm against the wrapper.
                "classes": yolo_results.names,
                "boxes": boxes.xyxy.tolist(),        # bounding-box coordinates
                "confidences": boxes.conf.tolist(),  # confidence scores
                "class_ids": boxes.cls.tolist()      # class ids
            }
            logger.warning(f"YOLO检测到违规目标: {details}")
            return details
        except Exception as e:
            # Deliberate best effort, see _check_ocr.
            logger.error(f"YOLO检测出错: {str(e)}", exc_info=True)
            return None
# # Usage example
# if __name__ == "__main__":
#     # Configuration paths (adjust to your environment)
#     FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
#     OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"  # OCR configuration file path
#     YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
#     KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
#
#     # Build the multi-model detector
#     detector = MultiModelViolationDetector(
#         forbidden_words_path=FORBIDDEN_WORDS_PATH,
#         ocr_config_path=OCR_CONFIG_PATH,  # pass the OCR configuration file path
#         yolo_model_path=YOLO_MODEL_PATH,
#         known_faces_dir=KNOWN_FACES_DIR,
#         ocr_confidence_threshold=0.5
#     )
#
#     # Read a test image (can be replaced with video-frame reading logic)
#     test_image_path = r"D:\Git\bin\video\ocr\images\img.png"
#     frame = cv2.imread(test_image_path)
#
#     if frame is not None:
#         has_violation, violation_type, details = detector.detect_violations(frame)
#         if has_violation:
#             print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
#         else:
#             print("未检测到任何违规内容")
#     else:
#         print(f"无法读取测试图像: {test_image_path}")