import os import cv2 import logging from rapidocr import RapidOCR class OCRViolationDetector: """ 封装RapidOCR引擎,用于检测图像帧中的违禁词。 核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测 """ def __init__(self, forbidden_words_path: str, ocr_config_path: str, ocr_confidence_threshold: float = 0.5, log_level: int = logging.INFO, log_file: str = None): """ 初始化OCR引擎、违禁词列表和日志配置。 Args: forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 ocr_config_path (str): OCR配置文件(如1.yaml)的路径。 ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。 log_level (int): 日志级别,默认为logging.INFO。 log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。 """ # 初始化日志(确保先初始化日志,后续操作可正常打日志) self.logger = self._setup_logger(log_level, log_file) # 加载违禁词(优先级:先加载配置,再初始化引擎) self.forbidden_words = self._load_forbidden_words(forbidden_words_path) # 初始化RapidOCR引擎(传入配置文件路径) self.ocr_engine = self._initialize_ocr(ocr_config_path) # 校验核心依赖是否就绪 self._check_dependencies() # 设置置信度阈值(限制在0~1范围,避免非法值) self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0)) self.logger.info(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}") def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger: """ 配置日志系统(避免重复添加处理器,支持控制台+文件双输出) Args: log_level: 日志级别(如logging.DEBUG、logging.INFO)。 log_file: 日志文件路径,为None时仅输出到控制台。 Returns: logging.Logger: 配置好的日志实例。 """ logger = logging.getLogger('OCRViolationDetector') logger.setLevel(log_level) # 避免重复添加处理器(防止日志重复输出) if logger.handlers: return logger # 定义日志格式(包含时间、模块名、级别、内容) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) # 1. 添加控制台处理器 console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) logger.addHandler(console_handler) # 2. 若指定日志文件,添加文件处理器(自动创建目录) if log_file: try: log_dir = os.path.dirname(log_file) # 若日志目录不存在,自动创建 if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True) self.logger.debug(f"自动创建日志目录: {log_dir}") file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.info(f"日志文件已配置: {log_file}") except Exception as e: logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}") return logger def _load_forbidden_words(self, path: str) -> set: """ 从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码) Args: path (str): 违禁词TXT文件路径。 Returns: set: 去重后的违禁词集合(空集合表示加载失败)。 """ forbidden_words = set() # 第一步:检查文件是否存在 if not os.path.exists(path): self.logger.error(f"违禁词文件不存在: {path}") return forbidden_words # 第二步:读取文件并处理内容 try: with open(path, 'r', encoding='utf-8') as f: # 过滤空行、去除首尾空格、去重 forbidden_words = { line.strip() for line in f if line.strip() # 跳过空行或纯空格行 } self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)") self.logger.debug(f"违禁词列表: {forbidden_words}") except UnicodeDecodeError: self.logger.error(f"违禁词文件编码错误(需UTF-8): {path}") except PermissionError: self.logger.error(f"无权限读取违禁词文件: {path}") except Exception as e: self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True) return forbidden_words def _initialize_ocr(self, config_path: str) -> RapidOCR | None: """ 初始化RapidOCR引擎(校验配置文件、捕获初始化异常) Args: config_path (str): RapidOCR配置文件(如1.yaml)路径。 Returns: RapidOCR | None: OCR引擎实例(None表示初始化失败)。 """ self.logger.info("开始初始化RapidOCR引擎...") # 第一步:检查配置文件是否存在 if not os.path.exists(config_path): self.logger.error(f"OCR配置文件不存在: {config_path}") return None # 第二步:初始化OCR引擎(捕获RapidOCR相关异常) try: ocr_engine = RapidOCR(config_path=config_path) self.logger.info("RapidOCR引擎初始化成功") return ocr_engine except ImportError: self.logger.error("RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)") except Exception as e: self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True) return None def _check_dependencies(self) -> None: """校验OCR引擎和违禁词列表是否就绪(输出警告日志)""" if not self.ocr_engine: self.logger.warning("⚠️ OCR引擎未就绪,违禁词检测功能将禁用") if not self.forbidden_words: self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用") def detect(self, frame) -> tuple[bool, list, list]: """ 对单帧图像进行OCR违禁词检测(核心方法) Args: frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。 Returns: tuple[bool, list, list]: - 第一个元素:是否检测到违禁词(True/False); - 第二个元素:检测到的违禁词列表(空列表表示无违禁词); - 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。 """ # 初始化返回结果 has_violation = False violation_words = [] violation_confs = [] # 前置校验:1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在 if frame is None or frame.size == 0: self.logger.warning("输入图像帧为空或无效,跳过OCR检测") return has_violation, violation_words, violation_confs if not self.ocr_engine or not self.forbidden_words: self.logger.debug("OCR引擎未就绪或违禁词为空,跳过OCR检测") return has_violation, violation_words, violation_confs try: # 1. 执行OCR识别(获取RapidOCR原始结果) self.logger.debug("开始执行OCR识别...") ocr_result = self.ocr_engine(frame) self.logger.debug(f"RapidOCR原始结果: {ocr_result}") # 2. 校验OCR结果是否有效(避免None或格式异常) if ocr_result is None: self.logger.debug("OCR识别未返回任何结果(图像无文本或识别失败)") return has_violation, violation_words, violation_confs # 3. 检查txts和scores是否存在且不为None if not hasattr(ocr_result, 'txts') or ocr_result.txts is None: self.logger.warning("OCR结果中txts为None或不存在") return has_violation, violation_words, violation_confs if not hasattr(ocr_result, 'scores') or ocr_result.scores is None: self.logger.warning("OCR结果中scores为None或不存在") return has_violation, violation_words, violation_confs # 4. 转为列表并去None(防止单个元素为None) # 确保txts是可迭代的,如果不是则转为空列表 if not isinstance(ocr_result.txts, (list, tuple)): self.logger.warning(f"OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}") texts = [] else: texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)] # 确保scores是可迭代的,如果不是则转为空列表 if not isinstance(ocr_result.scores, (list, tuple)): self.logger.warning(f"OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}") confidences = [] else: confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))] # 5. 校验文本和置信度列表长度是否一致(避免zip迭代错误) if len(texts) != len(confidences): self.logger.warning( f"OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测") return has_violation, violation_words, violation_confs if len(texts) == 0: self.logger.debug("OCR未识别到任何有效文本") return has_violation, violation_words, violation_confs # 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤) self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})") for text, conf in zip(texts, confidences): # 过滤低置信度结果 if conf < self.OCR_CONFIDENCE_THRESHOLD: self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过") continue # 检查当前文本是否包含违禁词(支持一个文本含多个违禁词) matched_words = [word for word in self.forbidden_words if word in text] if matched_words: has_violation = True # 记录所有匹配的违禁词和对应置信度 violation_words.extend(matched_words) violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用 self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})") except Exception as e: # 捕获所有异常,确保不中断上层调用 self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True) return has_violation, violation_words, violation_confs