import logging
import os
from typing import Optional

# Third-party imports are guarded so that a missing package degrades the
# detector (engine unavailable -> detect() reports nothing, with an error in
# the log) instead of crashing at import time, which matches how every other
# engine-initialisation failure is handled below.
try:
    import cv2  # only referenced by the commented-out test_image helper
except ImportError:
    cv2 = None

try:
    from rapidocr import RapidOCR
except ImportError:
    RapidOCR = None


class OCRViolationDetector:
    """Wrap a RapidOCR engine to detect forbidden words in image frames."""

    # Historical default location of the RapidOCR YAML config; used when the
    # caller does not pass ``config_path``.
    DEFAULT_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_confidence_threshold: float = 0.5,
                 log_level: int = logging.INFO,
                 log_file: Optional[str] = None,
                 config_path: Optional[str] = None):
        """Initialise logging, the forbidden-word list and the OCR engine.

        Args:
            forbidden_words_path: Path to a UTF-8 ``.txt`` file with one
                forbidden word per line.
            ocr_confidence_threshold: OCR results scoring below this value
                are ignored by :meth:`detect`.
            log_level: Logging level, ``logging.INFO`` by default.
            log_file: Optional log file path; when omitted, logs go to the
                console only.
            config_path: Optional path to the RapidOCR YAML config; defaults
                to :attr:`DEFAULT_CONFIG_PATH`.
        """
        self.logger = self._setup_logger(log_level, log_file)
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.ocr_engine = self._initialize_ocr(config_path)

        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
        self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")

    def _setup_logger(self, log_level: int,
                      log_file: Optional[str] = None) -> logging.Logger:
        """Configure and return the shared ``OCRViolationDetector`` logger.

        The logger is looked up by name, so every instance shares it. The
        console handler is added only when the logger has no handlers yet,
        and a file handler is added only when none is attached for the same
        file, so repeated construction never duplicates output.

        Args:
            log_level: Level applied to the logger.
            log_file: Optional log file path; ``None`` means console only.

        Returns:
            The configured logger instance.
        """
        logger = logging.getLogger('OCRViolationDetector')
        logger.setLevel(log_level)

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        if not logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

        if not log_file:
            return logger

        # FileHandler stores the absolute path in ``baseFilename``.
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == target
            for handler in logger.handlers
        )
        if already_attached:
            return logger

        try:
            log_dir = os.path.dirname(log_file)
            if log_dir:
                # exist_ok avoids the check-then-create race.
                os.makedirs(log_dir, exist_ok=True)

            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
            logger.info(f"日志文件将保存至: {log_file}")
        except Exception as e:
            # File logging is best effort: fall back to console only.
            logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台")

        return logger

    def _load_forbidden_words(self, path):
        """Load the forbidden-word set from a text file.

        Args:
            path: Path to the word list, one word per line.

        Returns:
            A set of stripped, non-empty lines; empty when the file is
            missing or cannot be read (the problem is logged).
        """
        words = set()
        if not os.path.exists(path):
            self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
            return words

        try:
            # utf-8-sig reads plain UTF-8 and also drops a leading BOM, which
            # would otherwise become part of the first word.
            with open(path, 'r', encoding='utf-8-sig') as f:
                stripped_lines = (line.strip() for line in f)
                words = {word for word in stripped_lines if word}
            self.logger.info(f"成功加载 {len(words)} 个违禁词。")
        except Exception as e:
            self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")

        return words

    def _initialize_ocr(self, config_path: Optional[str] = None):
        """Create the RapidOCR engine.

        Args:
            config_path: Path to the RapidOCR YAML config; defaults to
                :attr:`DEFAULT_CONFIG_PATH` when omitted.

        Returns:
            The engine instance, or ``None`` when the package is not
            installed, the config file is missing, or construction fails.
        """
        self.logger.info("正在初始化RapidOCR引擎...")
        config_path = config_path or self.DEFAULT_CONFIG_PATH

        if RapidOCR is None:
            self.logger.error("RapidOCR引擎初始化失败: 未安装 rapidocr 依赖")
            return None

        try:
            if not os.path.exists(config_path):
                self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
                return None

            engine = RapidOCR(config_path=config_path)
            self.logger.info("RapidOCR引擎初始化成功。")
            return engine
        except Exception as e:
            self.logger.error(f"RapidOCR引擎初始化失败: {e}")
            return None

    def detect(self, frame):
        """Run OCR on one frame and collect every forbidden word found.

        Args:
            frame: Image passed unchanged to the OCR engine.

        Returns:
            A tuple ``(has_violation, words, confidences)``. ``words`` holds
            one entry per (recognised text, forbidden word) match and
            ``confidences`` holds the OCR score of the matching text at the
            same index. Returns ``(False, [], [])`` when the engine or the
            word list is unavailable. An OCR error is logged and whatever
            was collected up to that point is returned.
        """
        # Logged at debug level rather than printed: this runs per frame.
        self.logger.debug("收到帧")
        if not self.ocr_engine or not self.forbidden_words:
            return False, [], []

        all_prohibited = []   # every forbidden word detected
        all_confidences = []  # OCR score for each detected word

        try:
            result = self.ocr_engine(frame)
            # Lazy %-formatting: the result repr is only built when debug
            # logging is actually enabled.
            self.logger.debug("RapidOCR 原始返回结果: %s", result)

            if result is None:
                return False, [], []

            # The fields may be absent or None when no text was recognised;
            # treat both as "no results" instead of failing inside zip().
            texts = getattr(result, 'txts', None)
            if texts is None:
                texts = ()
            confidences = getattr(result, 'scores', None)
            if confidences is None:
                confidences = ()

            threshold = self.OCR_CONFIDENCE_THRESHOLD
            for text, conf in zip(texts, confidences):
                if conf < threshold:
                    self.logger.debug(
                        "文本 '%s' 置信度 %.4f 低于阈值,跳过", text, conf
                    )
                    continue

                # One text may contain several forbidden words.
                for word in self.forbidden_words:
                    if word in text:
                        self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
                        all_prohibited.append(word)
                        all_confidences.append(conf)

        except Exception as e:
            self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)

        return len(all_prohibited) > 0, all_prohibited, all_confidences

    # def test_image(self, image_path: str, show_image: bool = True) -> tuple:
    #     """
    #     Run forbidden-word OCR detection on a single image and show the result.
    #
    #     Args:
    #         image_path (str): Path to the image file.
    #         show_image (bool): Whether to display the image, True by default.
    #
    #     Returns:
    #         tuple: (has_violation, words, confidences)
    #     """
    #     # Check that the image file exists
    #     if not os.path.exists(image_path):
    #         self.logger.error(f"图片文件不存在: {image_path}")
    #         return False, [], []
    #
    #     try:
    #         # Read the image
    #         frame = cv2.imread(image_path)
    #         if frame is None:
    #             self.logger.error(f"无法读取图片: {image_path}")
    #             return False, [], []
    #
    #         self.logger.info(f"开始处理图片: {image_path}")
    #
    #         # Run detection
    #         has_violation, violations, confidences = self.detect(frame)
    #
    #         # Report the detection result
    #         if has_violation:
    #             self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
    #             for word, conf in zip(violations, confidences):
    #                 self.logger.info(f"- {word} (置信度: {conf:.4f})")
    #         else:
    #             self.logger.info("图片中未检测到违禁词")
    #         # Show the image (if requested)
    #         if show_image:
    #             # Downscale the image for display when it is too large
    #             height, width = frame.shape[:2]
    #             max_size = 800
    #             if max(height, width) > max_size:
    #                 scale = max_size / max(height, width)
    #                 frame = cv2.resize(frame, None, fx=scale, fy=scale)
    #
    #             cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame)
    #             cv2.waitKey(0)  # wait for a key press
    #             cv2.destroyAllWindows()
    #
    #         return has_violation, violations, confidences
    #
    #     except Exception as e:
    #         self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True)
    #         return False, [], []
    #
    #
# # Usage example
# if __name__ == "__main__":
#     # Configuration
#     forbidden_words_path = "forbidden_words.txt"  # forbidden-word file path
#     test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png"  # test image path
#     ocr_threshold = 0.6  # OCR confidence threshold
#
#     # Create the detector instance
#     detector = OCRViolationDetector(
#         forbidden_words_path=forbidden_words_path,
#         ocr_confidence_threshold=ocr_threshold,
#         log_level=logging.INFO,
#         log_file="ocr_detection.log"
#     )
#
#     # Test the image
#     detector.test_image(test_image_path, show_image=True)