import os

import cv2  # noqa: F401  -- only used by the commented-out manual test below

from rapidocr import RapidOCR

from logger_config import logger

# Default RapidOCR config file, resolved relative to the current working
# directory. Built with os.path.join so it resolves on both Windows and POSIX;
# a literal like r".\config\1.yaml" is a single backslash-containing filename
# on POSIX and is never found there.
DEFAULT_OCR_CONFIG_PATH = os.path.join("config", "1.yaml")


class OCRViolationDetector:
    """
    Wrap a RapidOCR engine and report forbidden words found in image frames.
    """

    def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
                 config_path: str = DEFAULT_OCR_CONFIG_PATH):  # lower threshold -> higher recall
        """
        Load the forbidden-word list and initialise the OCR engine.

        Args:
            forbidden_words_path (str): path to a .txt file with one forbidden word per line.
            ocr_confidence_threshold (float): OCR results scoring below this are ignored.
            config_path (str): RapidOCR YAML config path; defaults to config/1.yaml
                relative to the current working directory.
        """
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
        self.ocr_engine = self._initialize_ocr(config_path)

    def _load_forbidden_words(self, path):
        """
        Load forbidden words from a UTF-8 text file (kept consistent with rapidocr_test.py).

        Each line is stripped of surrounding whitespace and blank lines are dropped.
        Comment-looking lines are deliberately NOT filtered, to match the test code.
        The file is decoded with 'utf-8-sig' so a leading BOM (e.g. from Windows
        Notepad) is removed; str.strip() does not remove U+FEFF, so without this
        the first word in the file would never match.

        Returns:
            set[str]: the loaded words. Empty if the file is missing or unreadable,
            which makes detect() skip detection entirely.
        """
        if not os.path.exists(path):
            logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
            return set()
        try:
            with open(path, 'r', encoding='utf-8-sig') as f:
                words = {line.strip() for line in f if line.strip()}
        except (OSError, UnicodeError) as e:
            logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
            return set()
        logger.info(f"成功加载 {len(words)} 个违禁词。")
        return words

    def _initialize_ocr(self, config_path=DEFAULT_OCR_CONFIG_PATH):
        """
        Create the RapidOCR engine from ``config_path``.

        Returns:
            The RapidOCR instance, or None if construction raised. In that case
            detect() always reports no violations.
        """
        logger.info("正在初始化RapidOCR引擎...")
        if not os.path.isfile(config_path):
            # Surface a missing config explicitly instead of letting it go unnoticed.
            logger.warning(f"警告:未找到OCR配置文件 {config_path}")
        try:
            engine = RapidOCR(config_path=config_path)
        except Exception as e:  # third-party init can fail in many ways; degrade gracefully
            logger.error(f"RapidOCR引擎初始化失败: {e}")
            return None
        logger.info("RapidOCR引擎初始化成功。")
        return engine

    def detect(self, frame):
        """
        Run OCR on a single frame and collect every forbidden word it contains.

        A word is reported once per recognised text line that contains it, and
        only if that line's score is >= OCR_CONFIDENCE_THRESHOLD.

        Args:
            frame: the image passed straight to the RapidOCR engine (e.g. a
                cv2.imread array). None is treated as "nothing to detect".

        Returns:
            tuple: (has_violation, forbidden_words, confidences), where the two
            lists are parallel (confidences[i] is the OCR score of the text line
            in which forbidden_words[i] was found).
        """
        if not self.ocr_engine or not self.forbidden_words or frame is None:
            return False, [], []

        all_prohibited = []   # every forbidden word hit
        all_confidences = []  # OCR score of the text line for each hit

        try:
            result = self.ocr_engine(frame)
            logger.debug(f"RapidOCR 原始返回结果: {result}")
            if result is None:
                return False, [], []

            # When no text is recognised the result's txts/scores may be None;
            # treat that as "no text" rather than letting zip(None, ...) raise
            # and log a spurious error for every empty frame.
            texts = getattr(result, 'txts', None)
            confidences = getattr(result, 'scores', None)
            if texts is None or confidences is None:
                return False, [], []

            for text, conf in zip(texts, confidences):
                if conf < self.OCR_CONFIDENCE_THRESHOLD:
                    logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
                    continue
                # A single text line may contain several forbidden words.
                for word in self.forbidden_words:
                    if word in text:
                        logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
                        all_prohibited.append(word)
                        all_confidences.append(conf)
        except Exception as e:  # best-effort per frame: log and report what was found so far
            logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)

        return len(all_prohibited) > 0, all_prohibited, all_confidences


# Manual test helper, kept for reference (not executed):
#
# def test_single_image():
#     """测试单张图片的OCR违规检测(显示所有违禁词)"""
#     TEST_IMAGE_PATH = r"ocr/images/img_7.png"  # forward slashes: portable path
#     FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt"
#     CONFIDENCE_THRESHOLD = 0.5
#
#     detector = OCRViolationDetector(
#         forbidden_words_path=FORBIDDEN_WORDS_PATH,
#         ocr_confidence_threshold=CONFIDENCE_THRESHOLD
#     )
#
#     if not os.path.exists(TEST_IMAGE_PATH):
#         print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}")
#         return
#
#     frame = cv2.imread(TEST_IMAGE_PATH)
#     if frame is None:
#         print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}")
#         return
#
#     # Run detection
#     has_violation, words, confidences = detector.detect(frame)
#
#     # Print every forbidden word found
#     if has_violation:
#         print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
#         for word, conf in zip(words, confidences):
#             print(f"- {word}(置信度:{conf:.4f})")
#     else:
#         print("测试结果:图片中未检测到违禁词")
#
#
# if __name__ == "__main__":
#     print("开始单张图片OCR违规检测测试...")
#     test_single_image()
#     print("测试完成")