class FaceRecognizer:
    """Recognize known people in image frames using InsightFace.

    Known faces are read from ``known_faces_dir``, which is expected to hold
    one sub-directory per person. The sub-directory name is used as the
    person's name and every supported image inside it contributes one
    embedding; the embeddings of a person are averaged into one reference
    vector.
    """

    # Extensions accepted as reference photos. '.webp' is included because
    # reference photos are commonly stored in that format and were previously
    # skipped without any message.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, known_faces_dir: str):
        """Create the engine and load every known face.

        Args:
            known_faces_dir: Directory holding one sub-directory per person.
        """
        self.known_faces_dir = known_faces_dir
        self.app = self._initialize_insightface()
        self.known_faces_embeddings = {}
        self.known_faces_names = []
        self._load_known_faces()

    def _initialize_insightface(self):
        """Build and prepare the InsightFace ``FaceAnalysis`` application.

        Returns:
            The prepared application, or ``None`` when initialisation failed
            (the failure is printed, not raised).
        """
        print("正在初始化InsightFace人脸识别引擎...")
        try:
            # 'buffalo_l' bundles detection, alignment and recognition models.
            app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
            # ctx_id=0 requests the first GPU; det_size trades detection
            # speed against accuracy.
            app.prepare(ctx_id=0, det_size=(640, 640))
            print("InsightFace人脸识别引擎初始化成功。")
            return app
        except Exception as e:
            print(f"InsightFace人脸识别引擎初始化失败: {e}")
            print("请确保已安装insightface和onnxruntime,并且模型文件已下载或可访问。")
            return None

    def _load_known_faces(self):
        """Scan ``known_faces_dir`` and compute one embedding per person."""
        if not os.path.exists(self.known_faces_dir):
            print(f"警告: 已知人脸目录 '{self.known_faces_dir}' 不存在。请创建并放入照片。")
            os.makedirs(self.known_faces_dir, exist_ok=True)
            return

        # Without an engine every image would fail with the same
        # AttributeError; report the problem once instead.
        if self.app is None:
            print("警告: 人脸识别引擎未初始化,已跳过已知人脸加载。")
            return

        print(f"正在加载已知人脸特征从: '{self.known_faces_dir}'...")
        for person_name in os.listdir(self.known_faces_dir):
            person_dir = os.path.join(self.known_faces_dir, person_name)
            if not os.path.isdir(person_dir):
                continue

            print(f" 加载人物: {person_name}")
            embeddings = self._extract_embeddings(person_dir)
            if embeddings:
                # Average all photos of a person into one reference vector.
                self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
                self.known_faces_names.append(person_name)
                print(f" 人物 '{person_name}' 加载完成,共 {len(embeddings)} 张照片。")
            else:
                print(f" 警告: 人物 '{person_name}' 没有有效的人脸特征,已跳过。")
        print(f"已知人脸加载完成。共 {len(self.known_faces_names)} 个人物。")

    def _extract_embeddings(self, person_dir):
        """Return the face embeddings found in the images of ``person_dir``."""
        embeddings = []
        for filename in os.listdir(person_dir):
            if not filename.lower().endswith(self._IMAGE_EXTENSIONS):
                continue
            image_path = os.path.join(person_dir, filename)
            try:
                img = cv2.imread(image_path)
                if img is None:
                    print(f" 警告: 无法读取图片 '{image_path}',已跳过。")
                    continue

                faces = self.app.get(img)
                if faces:
                    # A reference photo normally shows one person; use the
                    # first detected face.
                    embeddings.append(faces[0].embedding)
                    print(f" 成功提取 '{filename}' 的人脸特征。")
                else:
                    print(f" 警告: 在图片 '{filename}' 中未检测到人脸,已跳过。")
            except Exception as e:
                print(f" 处理图片 '{image_path}' 时发生错误: {e}")
        return embeddings

    @staticmethod
    def _cosine_similarity(first, second) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is zero)."""
        vec_a = np.asarray(first, dtype=np.float32)
        vec_b = np.asarray(second, dtype=np.float32)
        norm_product = float(np.linalg.norm(vec_a)) * float(np.linalg.norm(vec_b))
        if norm_product == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b)) / norm_product

    def recognize(self, frame, threshold=0.4):
        """Look for a known person in ``frame``.

        Every detected face is compared with every known person and the
        single best-scoring pair is reported, so a frame that resembles
        several known people is labelled with the closest one rather than
        with whichever name happened to be checked first.

        Args:
            frame: Image as a NumPy array in BGR order.
            threshold: Minimum cosine similarity (range -1.0 to 1.0) that
                counts as a match; higher is stricter.

        Returns:
            ``(True, name, similarity)`` for the best match at or above the
            threshold, otherwise ``(False, None, None)``.
        """
        if not self.app or not self.known_faces_names:
            return False, None, None

        faces = self.app.get(frame)
        if not faces:
            return False, None, None

        best_name = None
        best_similarity = None
        for face in faces:
            for known_name in self.known_faces_names:
                similarity = self._cosine_similarity(
                    face.embedding, self.known_faces_embeddings[known_name]
                )
                if best_similarity is None or similarity > best_similarity:
                    best_name, best_similarity = known_name, similarity

        if best_similarity is not None and best_similarity >= threshold:
            print(f"!!! 人脸识别检测到已知人物: '{best_name}' (相似度: {best_similarity:.4f}) !!!")
            return True, best_name, best_similarity

        return False, None, None
b/ocr/images/img_7.png differ diff --git a/ocr/known_faces/B/102-f.jpg_1140x855.jpg b/ocr/known_faces/B/102-f.jpg_1140x855.jpg new file mode 100644 index 0000000..9184dbc Binary files /dev/null and b/ocr/known_faces/B/102-f.jpg_1140x855.jpg differ diff --git a/ocr/known_faces/B/104-1.jpg b/ocr/known_faces/B/104-1.jpg new file mode 100644 index 0000000..e6a2a8f Binary files /dev/null and b/ocr/known_faces/B/104-1.jpg differ diff --git a/ocr/known_faces/B/110627170414_boxilai_304x304_cns.jpg.webp b/ocr/known_faces/B/110627170414_boxilai_304x304_cns.jpg.webp new file mode 100644 index 0000000..9fb4dd6 Binary files /dev/null and b/ocr/known_faces/B/110627170414_boxilai_304x304_cns.jpg.webp differ diff --git a/ocr/known_faces/B/14sino-qiu02-master1050.jpg b/ocr/known_faces/B/14sino-qiu02-master1050.jpg new file mode 100644 index 0000000..b18210c Binary files /dev/null and b/ocr/known_faces/B/14sino-qiu02-master1050.jpg differ diff --git a/ocr/known_faces/B/xilai003.webp b/ocr/known_faces/B/xilai003.webp new file mode 100644 index 0000000..186417b Binary files /dev/null and b/ocr/known_faces/B/xilai003.webp differ diff --git a/ocr/known_faces/W/120208041156_wang_lijun_304x171_xinhua.jpg.webp b/ocr/known_faces/W/120208041156_wang_lijun_304x171_xinhua.jpg.webp new file mode 100644 index 0000000..520bc13 Binary files /dev/null and b/ocr/known_faces/W/120208041156_wang_lijun_304x171_xinhua.jpg.webp differ diff --git a/ocr/known_faces/W/2f0f70db48.jpg b/ocr/known_faces/W/2f0f70db48.jpg new file mode 100644 index 0000000..7dbeda5 Binary files /dev/null and b/ocr/known_faces/W/2f0f70db48.jpg differ diff --git a/ocr/known_faces/W/lijun-jumbo.jpg b/ocr/known_faces/W/lijun-jumbo.jpg new file mode 100644 index 0000000..c1b742d Binary files /dev/null and b/ocr/known_faces/W/lijun-jumbo.jpg differ diff --git a/ocr/known_faces/X/1404123658308624.jpg b/ocr/known_faces/X/1404123658308624.jpg new file mode 100644 index 0000000..1ece977 Binary files /dev/null and 
class MultiModelViolationDetector:
    """Run OCR, face recognition and YOLO one after another on a frame.

    The stages run in the order OCR -> face recognition -> YOLO and the
    first stage that reports a violation decides the result. A stage that
    raises is logged and skipped so the remaining stages still run.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 yolo_model_path: str,
                 known_faces_dir: str,
                 ocr_confidence_threshold: float = 0.5):
        """Create the three underlying detectors.

        Args:
            forbidden_words_path: Path of the forbidden-word list file.
            ocr_config_path: Path of the OCR configuration file.
            yolo_model_path: Path of the YOLO model file.
            known_faces_dir: Directory holding the known-face photos.
            ocr_confidence_threshold: Confidence threshold passed to the
                OCR detector.
        """
        self.ocr_detector = OCRViolationDetector(
            forbidden_words_path=forbidden_words_path,
            ocr_config_path=ocr_config_path,
            ocr_confidence_threshold=ocr_confidence_threshold
        )
        self.face_recognizer = FaceRecognizer(
            known_faces_dir=known_faces_dir
        )
        # YOLO is created last, matching the order in which stages run.
        self.yolo_detector = YoloViolationDetector(
            model_path=yolo_model_path
        )

        logger.info("多模型违规检测器初始化完成")

    def detect_violations(self, frame):
        """Check ``frame`` with every model and return the first violation.

        Args:
            frame: Image as a NumPy array in BGR order.

        Returns:
            tuple: ``(has_violation, violation_type, details)`` where
            ``violation_type`` is ``'ocr'``, ``'face'``, ``'yolo'`` or
            ``None`` and ``details`` is the dict built by the stage that
            fired (``None`` when nothing was found).
        """
        # An unusable frame must not reach the models: the face and YOLO
        # stages would raise, and an Ultralytics model called without a
        # source falls back to its bundled sample images, which could be
        # reported as a violation of this frame.
        if frame is None or getattr(frame, "size", 1) == 0:
            logger.warning("输入帧为空或无效,跳过违规检测")
            return (False, None, None)

        stages = (
            ("OCR检测", self._detect_ocr),
            ("人脸识别", self._detect_face),
            ("YOLO检测", self._detect_yolo),
        )
        for label, stage in stages:
            try:
                outcome = stage(frame)
            except Exception as e:
                # Keep going: one broken model must not disable the others.
                logger.error(f"{label}出错: {str(e)}", exc_info=True)
                continue
            if outcome is not None:
                return outcome

        return (False, None, None)

    def _detect_ocr(self, frame):
        """Return the OCR violation tuple, or ``None`` when nothing matched."""
        has_violation, words, confidences = self.ocr_detector.detect(frame)
        if not has_violation:
            return None
        details = {
            "words": words,
            "confidences": confidences
        }
        logger.warning(f"OCR检测到违禁内容: {details}")
        return (True, "ocr", details)

    def _detect_face(self, frame):
        """Return the face violation tuple, or ``None`` when nobody matched."""
        has_violation, name, similarity = self.face_recognizer.recognize(frame)
        if not has_violation:
            return None
        details = {
            "name": name,
            "similarity": similarity
        }
        logger.warning(f"人脸识别到违规人员: {details}")
        return (True, "face", details)

    def _detect_yolo(self, frame):
        """Return the YOLO violation tuple, or ``None`` when no box was found."""
        result = self.yolo_detector.detect(frame)
        boxes = result.boxes
        # Any detection counts as a violation here.
        if len(boxes) == 0:
            return None

        names = result.names
        class_ids = boxes.cls.tolist()

        def name_of(class_id):
            # Works for both mapping- and sequence-style name tables.
            index = int(class_id)
            try:
                return names[index]
            except (KeyError, IndexError, TypeError):
                return str(index)

        details = {
            # Full id -> name table of the model (kept for compatibility).
            "classes": names,
            # Name of each detected box, aligned with the lists below.
            "class_names": [name_of(class_id) for class_id in class_ids],
            "boxes": boxes.xyxy.tolist(),
            "confidences": boxes.conf.tolist(),
            "class_ids": class_ids
        }
        logger.warning(f"YOLO检测到违规目标: {details}")
        return (True, "yolo", details)
class OCRViolationDetector:
    """Detect forbidden words in image frames with a RapidOCR engine.

    Responsibilities: load the forbidden-word list, create the OCR engine
    and check single frames for forbidden words.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 ocr_confidence_threshold: float = 0.5,
                 log_level: int = logging.INFO,
                 log_file: str = None):
        """Set up logging, the forbidden-word list and the OCR engine.

        Args:
            forbidden_words_path: Path of the forbidden-word ``.txt`` file.
            ocr_config_path: Path of the RapidOCR configuration file.
            ocr_confidence_threshold: Minimum OCR confidence; clamped to
                the range 0..1.
            log_level: Logging level, ``logging.INFO`` by default.
            log_file: Optional log file path; console only when omitted.
        """
        # The logger comes first so every later step can report problems.
        self.logger = self._setup_logger(log_level, log_file)
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.ocr_engine = self._initialize_ocr(ocr_config_path)
        self._check_dependencies()

        self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
        self.logger.info(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")

    def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
        """Configure and return the shared ``OCRViolationDetector`` logger.

        The logger is process-wide: once it has handlers, later calls only
        update the level and return it unchanged, so a ``log_file`` passed
        by a later instance is not attached.

        Args:
            log_level: Logging level such as ``logging.DEBUG``.
            log_file: Optional log file path; console only when ``None``.

        Returns:
            The configured logger.
        """
        logger = logging.getLogger('OCRViolationDetector')
        logger.setLevel(log_level)

        # Adding handlers twice would duplicate every log line.
        if logger.handlers:
            return logger

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

        if log_file:
            try:
                log_dir = os.path.dirname(log_file)
                if log_dir and not os.path.exists(log_dir):
                    os.makedirs(log_dir, exist_ok=True)
                    # Use the local logger: self.logger is not assigned
                    # until this method returns, and touching it here
                    # raised an AttributeError that silently cancelled the
                    # file handler below.
                    logger.debug(f"自动创建日志目录: {log_dir}")

                file_handler = logging.FileHandler(log_file, encoding='utf-8')
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
                logger.info(f"日志文件已配置: {log_file}")
            except Exception as e:
                logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}")

        return logger

    def _load_forbidden_words(self, path: str) -> set:
        """Load the forbidden words from a UTF-8 text file, one per line.

        Blank lines are skipped, surrounding whitespace is stripped and
        duplicates collapse. A leading UTF-8 byte-order mark is ignored so
        the first word of a file saved "with BOM" still matches.

        Args:
            path: Path of the forbidden-word file.

        Returns:
            The set of words; empty when the file is missing or unreadable.
        """
        forbidden_words = set()

        if not os.path.exists(path):
            self.logger.error(f"违禁词文件不存在: {path}")
            return forbidden_words

        try:
            # 'utf-8-sig' reads plain UTF-8 too and drops an optional BOM.
            with open(path, 'r', encoding='utf-8-sig') as f:
                forbidden_words = {line.strip() for line in f if line.strip()}
            self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
            self.logger.debug(f"违禁词列表: {forbidden_words}")
        except UnicodeDecodeError:
            self.logger.error(f"违禁词文件编码错误(需UTF-8): {path}")
        except PermissionError:
            self.logger.error(f"无权限读取违禁词文件: {path}")
        except Exception as e:
            self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True)

        return forbidden_words

    def _initialize_ocr(self, config_path: str) -> "RapidOCR | None":
        """Create the RapidOCR engine from ``config_path``.

        Args:
            config_path: Path of the RapidOCR configuration file.

        Returns:
            The engine, or ``None`` when the file is missing or the engine
            could not be created (the reason is logged).
        """
        self.logger.info("开始初始化RapidOCR引擎...")

        if not os.path.exists(config_path):
            self.logger.error(f"OCR配置文件不存在: {config_path}")
            return None

        try:
            ocr_engine = RapidOCR(config_path=config_path)
            self.logger.info("RapidOCR引擎初始化成功")
            return ocr_engine
        except ImportError:
            self.logger.error("RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
        except Exception as e:
            self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True)

        return None

    def _check_dependencies(self) -> None:
        """Warn when the engine or the word list is missing."""
        if not self.ocr_engine:
            self.logger.warning("⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
        if not self.forbidden_words:
            self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用")

    @staticmethod
    def _as_list(value):
        """Return ``value`` as a list, or ``None`` if it is not a sequence of items."""
        if isinstance(value, (str, bytes)):
            return None
        try:
            return list(value)
        except TypeError:
            return None

    def detect(self, frame) -> tuple[bool, list, list]:
        """Run OCR on one frame and collect the forbidden words it shows.

        Recognised texts and their scores are paired by position first and
        filtered afterwards, so an unusable entry only drops its own pair
        and never shifts the confidence of another text.

        Args:
            frame: Image as a NumPy array in BGR order.

        Returns:
            tuple[bool, list, list]: whether any forbidden word was found,
            the matched words, and the confidence of the text each word
            came from (same order as the words).
        """
        has_violation = False
        violation_words = []
        violation_confs = []

        if frame is None or frame.size == 0:
            self.logger.warning("输入图像帧为空或无效,跳过OCR检测")
            return has_violation, violation_words, violation_confs
        if not self.ocr_engine or not self.forbidden_words:
            self.logger.debug("OCR引擎未就绪或违禁词为空,跳过OCR检测")
            return has_violation, violation_words, violation_confs

        try:
            self.logger.debug("开始执行OCR识别...")
            ocr_result = self.ocr_engine(frame)
            self.logger.debug(f"RapidOCR原始结果: {ocr_result}")

            if ocr_result is None:
                self.logger.debug("OCR识别未返回任何结果(图像无文本或识别失败)")
                return has_violation, violation_words, violation_confs

            raw_texts = getattr(ocr_result, 'txts', None)
            if raw_texts is None:
                self.logger.warning("OCR结果中txts为None或不存在")
                return has_violation, violation_words, violation_confs

            raw_scores = getattr(ocr_result, 'scores', None)
            if raw_scores is None:
                self.logger.warning("OCR结果中scores为None或不存在")
                return has_violation, violation_words, violation_confs

            texts = self._as_list(raw_texts)
            if texts is None:
                self.logger.warning(f"OCR txts不是可迭代类型,实际类型: {type(raw_texts)}")
                return has_violation, violation_words, violation_confs

            confidences = self._as_list(raw_scores)
            if confidences is None:
                self.logger.warning(f"OCR scores不是可迭代类型,实际类型: {type(raw_scores)}")
                return has_violation, violation_words, violation_confs

            # Compare the raw lengths: positions only line up when the
            # engine returned exactly one score per text.
            if len(texts) != len(confidences):
                self.logger.warning(
                    f"OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
                return has_violation, violation_words, violation_confs
            if len(texts) == 0:
                self.logger.debug("OCR未识别到任何有效文本")
                return has_violation, violation_words, violation_confs

            self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
            for raw_text, raw_conf in zip(texts, confidences):
                if not isinstance(raw_text, str):
                    continue
                text = raw_text.strip()
                if not text:
                    continue
                try:
                    # float() also accepts NumPy scalar scores.
                    conf = float(raw_conf)
                except (TypeError, ValueError):
                    continue

                if conf < self.OCR_CONFIDENCE_THRESHOLD:
                    self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
                    continue

                # One text can contain several forbidden words; they all
                # share the confidence of that text.
                matched_words = [word for word in self.forbidden_words if word in text]
                if matched_words:
                    has_violation = True
                    violation_words.extend(matched_words)
                    violation_confs.extend([conf] * len(matched_words))
                    self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")

        except Exception as e:
            # Never let an OCR failure interrupt the caller.
            self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True)

        return has_violation, violation_words, violation_confs
class ViolationDetector:
    """Load a YOLO ``.pt`` model and run it on single frames."""

    # Confidence used when neither the constructor nor the call overrides
    # it. The value was introduced as a temporarily lowered setting for
    # testing and is kept as the default for compatibility.
    conf_threshold = 0.2

    def __init__(self, model_path, conf_threshold: float = 0.2):
        """Load the model.

        Args:
            model_path (str): Path of the YOLO ``.pt`` model file.
            conf_threshold: Default minimum confidence for ``detect``.
        """
        logger.info(f"正在从 '{model_path}' 加载YOLO模型...")
        self.model = YOLO(model_path)
        self.conf_threshold = conf_threshold
        logger.info("YOLO模型加载成功。")

    def detect(self, frame, conf: "float | None" = None):
        """Run object detection on one frame.

        Args:
            frame: Image as a NumPy array in BGR order.
            conf: Minimum confidence for this call; the detector's
                ``conf_threshold`` is used when omitted.

        Returns:
            The first element of the sequence returned by the model call,
            i.e. the result object for this single frame.
        """
        threshold = self.conf_threshold if conf is None else conf
        results = self.model(frame, conf=threshold)
        return results[0]

    def draw_boxes(self, frame, result):
        """Return an image with the detection boxes drawn on it.

        Args:
            frame: Original frame; unused because ``result.plot()`` renders
                from the data stored in ``result``. Kept for API stability.
            result: Result object returned by ``detect``.

        Returns:
            Whatever ``result.plot()`` returns (the annotated image).
        """
        annotated_frame = result.plot()
        return annotated_frame