import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import cv2
import yaml

from .ocr_violation_detector import OCRViolationDetector
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
from .face_recognizer import FaceRecognizer


class MultiModelViolationDetector:
    """Run OCR, face-recognition and YOLO violation checks on a frame.

    The three detectors are called sequentially (OCR -> face -> YOLO). The
    first one that reports a violation short-circuits the pipeline and its
    result is returned. A detector that raises is reported via ``print`` and
    skipped, so one failing model does not stop the remaining ones.
    """

    def __init__(self, forbidden_words_path: str, ocr_config_path: str,
                 yolo_model_path: str, known_faces_dir: str,
                 ocr_confidence_threshold: float = 0.5):
        """Construct the three underlying detectors.

        Args:
            forbidden_words_path: Passed to ``OCRViolationDetector``.
            ocr_config_path: Passed to ``OCRViolationDetector``.
            yolo_model_path: Passed to the YOLO ``ViolationDetector`` as ``model_path``.
            known_faces_dir: Passed to ``FaceRecognizer``.
            ocr_confidence_threshold: Passed to ``OCRViolationDetector``.
        """
        # OCR forbidden-word detector
        self.ocr_detector = OCRViolationDetector(
            forbidden_words_path=forbidden_words_path,
            ocr_config_path=ocr_config_path,
            ocr_confidence_threshold=ocr_confidence_threshold
        )

        # Face recognizer
        self.face_recognizer = FaceRecognizer(
            known_faces_dir=known_faces_dir
        )

        # YOLO object detector
        self.yolo_detector = YoloViolationDetector(
            model_path=yolo_model_path
        )

        print("多模型违规检测器初始化完成")

    def detect_violations(self, frame) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
        """Check *frame* with OCR, then face recognition, then YOLO.

        ``frame`` is forwarded unchanged to every detector (presumably a
        BGR image array as produced by ``cv2.imread`` -- confirm with callers).

        Returns:
            ``(True, source, details)`` for the first detector that flags a
            violation, where ``source`` is ``"ocr"``, ``"face"`` or ``"yolo"``
            and ``details`` is that detector's detail dict; otherwise
            ``(False, None, None)``.
        """
        # (source tag, check, warning label, error label). Order matters:
        # the first detector that reports a violation wins.
        pipeline = (
            ("ocr", self._check_ocr, "OCR检测到违禁内容", "OCR检测出错"),
            ("face", self._check_face, "人脸识别到违规人员", "人脸识别出错"),
            ("yolo", self._check_yolo, "YOLO检测到违规目标", "YOLO检测出错"),
        )

        for source, check, warning_label, error_label in pipeline:
            try:
                details = check(frame)
            except Exception as e:
                # Deliberate best-effort: report the failing model and move on
                # so the remaining detectors still get a chance to run.
                print(f"错误: {error_label}: {str(e)}")
                continue
            if details is not None:
                print(f"警告: {warning_label}: {details}")
                return (True, source, details)

        # No detector reported a violation.
        return (False, None, None)

    def _check_ocr(self, frame) -> Optional[Dict[str, Any]]:
        """Return OCR violation details, or None when no violation is reported."""
        has_violation, words, confidences = self.ocr_detector.detect(frame)
        if not has_violation:
            return None
        return {
            "words": words,
            "confidences": confidences
        }

    def _check_face(self, frame) -> Optional[Dict[str, Any]]:
        """Return face-recognition violation details, or None when no violation is reported."""
        has_violation, name, similarity = self.face_recognizer.recognize(frame)
        if not has_violation:
            return None
        return {
            "name": name,
            "similarity": similarity
        }

    def _check_yolo(self, frame) -> Optional[Dict[str, Any]]:
        """Return YOLO detection details, or None when no boxes were detected.

        NOTE(review): assumes ``detect`` returns an ultralytics-style result
        (``.names`` id->name table, ``.boxes`` exposing ``xyxy``/``conf``/``cls``
        tensors) -- confirm against ``YoloViolationDetector``.
        """
        results = self.yolo_detector.detect(frame)
        boxes = results.boxes
        if len(boxes) == 0:
            return None

        class_ids = boxes.cls.tolist()
        return {
            # The model's full class table (kept as-is for existing consumers).
            "classes": results.names,
            # Names of the classes actually detected, one entry per box.
            "detected_classes": [self._class_name(results.names, cid) for cid in class_ids],
            "boxes": boxes.xyxy.tolist(),
            "confidences": boxes.conf.tolist(),
            "class_ids": class_ids
        }

    @staticmethod
    def _class_name(names, class_id) -> str:
        """Look up *class_id* in *names* (dict or list), falling back to the id itself.

        Never raises, so a missing label cannot cause a real detection to be
        swallowed by the caller's error handling.
        """
        try:
            index = int(class_id)
        except (TypeError, ValueError):
            return str(class_id)
        try:
            return str(names[index])
        except (KeyError, IndexError, TypeError):
            return str(index)


def load_config(config_path: str) -> dict:
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file, read as UTF-8.

    Returns:
        The parsed document. An empty file yields ``{}`` rather than ``None``.

    Raises:
        FileNotFoundError: The file does not exist.
        yaml.YAMLError: The file is not valid YAML.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"错误: 配置文件未找到: {config_path}")
        raise
    except yaml.YAMLError as e:
        print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
        raise
    except Exception as e:
        print(f"错误: 加载配置文件出错: {str(e)}")
        raise

    # yaml.safe_load returns None for an empty document; normalize so the
    # declared dict return type holds and ``config.get(...)`` keeps working.
    return config if config is not None else {}


# Usage example
# if __name__ == "__main__":
#     # Load the configuration file
#     config = load_config("config.yaml")  # config file path; adjust as needed
#
#     # Initialize the multi-model detector
#     detector = MultiModelViolationDetector(
#         forbidden_words_path=config["forbidden_words_path"],
#         ocr_config_path=config["ocr_config_path"],
#         yolo_model_path=config["yolo_model_path"],
#         known_faces_dir=config["known_faces_dir"],
#         ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
#     )
#
#     # Read a test image (can be replaced with video-frame reading logic)
#     test_image_path = config.get("test_image_path")  # test image path taken from the config file
#     if test_image_path:
#         frame = cv2.imread(test_image_path)
#
#         if frame is not None:
#             has_violation, violation_type, details = detector.detect_violations(frame)
#             if has_violation:
#                 print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
#             else:
#                 print("未检测到任何违规内容")
#         else:
#             print(f"无法读取测试图像: {test_image_path}")
#     else:
#         print("配置文件中未指定测试图像路径")