from typing import Any, Dict, Optional, Tuple

import cv2  # only referenced by the commented-out usage example at the bottom

from logger_config import logger
from ocr_violation_detector import OCRViolationDetector
from yolo_violation_detector import ViolationDetector as YoloViolationDetector
from face_recognizer import FaceRecognizer


class MultiModelViolationDetector:
    """Run OCR, face-recognition and YOLO violation checks on a frame, serially.

    The models are called in the order OCR -> face recognition -> YOLO (YOLO is
    deliberately last). The first model that reports a violation determines
    the result; later models are not run for that frame.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 yolo_model_path: str,
                 known_faces_dir: str,
                 ocr_confidence_threshold: float = 0.5) -> None:
        """Create the three underlying detectors.

        Args:
            forbidden_words_path: Path to the forbidden-words file.
            ocr_config_path: Path to the OCR configuration file (1.yaml).
            yolo_model_path: Path to the YOLO model file.
            known_faces_dir: Directory containing the known faces.
            ocr_confidence_threshold: OCR confidence threshold.
        """
        # OCR detector (receives the configuration file path).
        self.ocr_detector = OCRViolationDetector(
            forbidden_words_path=forbidden_words_path,
            ocr_config_path=ocr_config_path,
            ocr_confidence_threshold=ocr_confidence_threshold
        )

        # Face recognizer.
        self.face_recognizer = FaceRecognizer(
            known_faces_dir=known_faces_dir
        )

        # YOLO detector (initialised last, matching the detection order).
        self.yolo_detector = YoloViolationDetector(
            model_path=yolo_model_path
        )

        logger.info("多模型违规检测器初始化完成")

    def detect_violations(self, frame) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
        """Check a frame with OCR, then face recognition, then YOLO.

        Each stage runs inside its own ``try``: an exception raised by one
        model is logged and the next model is still tried (best-effort).

        Args:
            frame: Input video frame (NumPy array, BGR format).

        Returns:
            tuple: ``(has_violation, violation_type, details)`` where
            ``violation_type`` is ``'ocr'``, ``'face'``, ``'yolo'`` or ``None``
            and ``details`` is the dict built for the model that fired, or
            ``None`` when no violation was found.
        """
        # A missing/empty frame cannot contain a violation; bail out once
        # instead of handing it to every model in turn.
        if frame is None or getattr(frame, "size", None) == 0:
            logger.warning("输入帧为空,跳过违规检测")
            return (False, None, None)

        # (violation type, error-log prefix, stage callable) in execution order.
        stages = (
            ("ocr", "OCR检测出错", self._check_ocr),
            ("face", "人脸识别出错", self._check_face),
            ("yolo", "YOLO检测出错", self._check_yolo),
        )

        for violation_type, error_prefix, run_stage in stages:
            try:
                details = run_stage(frame)
            except Exception as e:
                # Deliberately broad: one failing model must not block the rest.
                logger.error(f"{error_prefix}: {str(e)}", exc_info=True)
                continue
            if details is not None:
                return (True, violation_type, details)

        # No model reported a violation.
        return (False, None, None)

    def _check_ocr(self, frame) -> Optional[Dict[str, Any]]:
        """Run the OCR forbidden-word check; return its details dict or None."""
        ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
        if not ocr_has_violation:
            return None
        details = {
            "words": ocr_words,
            "confidences": ocr_confs
        }
        logger.warning(f"OCR检测到违禁内容: {details}")
        return details

    def _check_face(self, frame) -> Optional[Dict[str, Any]]:
        """Run face recognition; return its details dict or None."""
        face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
        if not face_has_violation:
            return None
        details = {
            "name": face_name,
            "similarity": face_similarity
        }
        logger.warning(f"人脸识别到违规人员: {details}")
        return details

    def _check_yolo(self, frame) -> Optional[Dict[str, Any]]:
        """Run YOLO object detection; return its details dict or None."""
        yolo_results = self.yolo_detector.detect(frame)
        # Any detected box is treated as a violation here; what actually counts
        # as a violating target should be defined by the business rules.
        if len(yolo_results.boxes) == 0:
            return None
        boxes = yolo_results.boxes
        details = {
            # NOTE(review): this is the result's `names` attribute as-is; it is
            # not filtered by the detected class ids below - confirm intended.
            "classes": yolo_results.names,
            "boxes": boxes.xyxy.tolist(),        # bounding-box coordinates
            "confidences": boxes.conf.tolist(),  # confidences
            "class_ids": boxes.cls.tolist()      # class ids
        }
        logger.warning(f"YOLO检测到违规目标: {details}")
        return details


# # Usage example
# if __name__ == "__main__":
#     # File paths (adjust to your environment)
#     FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
#     OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"  # OCR configuration file path
#     YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
#     KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
#
#     # Initialise the multi-model detector
#     detector = MultiModelViolationDetector(
#         forbidden_words_path=FORBIDDEN_WORDS_PATH,
#         ocr_config_path=OCR_CONFIG_PATH,  # pass the OCR configuration file path
#         yolo_model_path=YOLO_MODEL_PATH,
#         known_faces_dir=KNOWN_FACES_DIR,
#         ocr_confidence_threshold=0.5
#     )
#
#     # Read a test image (can be replaced with video-frame reading logic)
#     test_image_path = r"D:\Git\bin\video\ocr\images\img.png"
#     frame = cv2.imread(test_image_path)
#
#     if frame is not None:
#         has_violation, violation_type, details = detector.detect_violations(frame)
#         if has_violation:
#             print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
#         else:
#             print("未检测到任何违规内容")
#     else:
#         print(f"无法读取测试图像: {test_image_path}")